In [8]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CustomContrastiveDataset(Dataset):
    def __init__(self, data_dir,base_transform=None,transform1=None, transform2=None):
        self.data_dir = data_dir
        self.base_transform=base_transform
        self.transform_pos = transform1
        self.transform_neg = transform2
        self.image_paths = self._load_image_paths(data_dir)
        self.patient_data = self._group_images_by_patient()

    def _load_image_paths(self, data_dir):
        return [os.path.join(data_dir, img_name) for img_name in os.listdir(data_dir) if img_name.endswith('.png')]
    
    def _group_images_by_patient(self):
        patient_dict = {}
        for img_path in self.image_paths:
            patient_id = self.get_patient_id(img_path)
            if patient_id not in patient_dict:
                patient_dict[patient_id] = []
            patient_dict[patient_id].append(img_path)
        return patient_dict
    
    def get_patient_id(self, img_name):
        return img_name.split('_')[0]
    
    def get_side_view(self, img_name):
        if 'r' in img_name:
            return 'r' 
        elif 'l' in img_name:
            return 'l'
        return None
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("L")
        image_name = os.path.basename(img_path)
        patient_id = self.get_patient_id(image_name)
        side_view = self.get_side_view(image_name)

        if random.random() > 0.5:
            positive_img = img
            if self.transform_pos:
                positive_img = self.transform_pos(img)
                img = self.base_transform(img) 
            return img, positive_img, 1
        else:
            other_patients = [pid for pid in self.patient_data if pid != patient_id]
            if other_patients:
                negative_patient_id = random.choice(other_patients)
                negative_img_path = random.choice(self.patient_data[negative_patient_id])
            else:
                all_images = self.patient_data[patient_id]
                negative_img_path = random.choice([img_path for img_path in all_images if self.get_side_view(img_path) != side_view])
            negative_img = Image.open(negative_img_path).convert("L")
            if self.transform_neg:
                negative_img = self.transform_neg(negative_img)
                img = self.base_transform(img) 
            return img, negative_img, 0

In [9]:
base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

transform1 = transforms.Compose([
    base_transform,
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15), # 你可以根据需要调整旋转角度
])

transform2 = transforms.Compose([
    base_transform,
    transforms.RandomRotation(degrees=5),  # 你可以根据需要调整旋转角度
])

In [10]:
dataset = CustomContrastiveDataset(data_dir=r"F:\Dataset\Train_tumor\train_phase2\rgb_all_png\train_extend_DBT_slice_rgb_patch3", base_transform=base_transform,transform1=transform1, transform2=transform2)

In [11]:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveNet(nn.Module):
    def __init__(self):
        super(ContrastiveNet, self).__init__()
        # 特征提取器
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 输入通道为1（灰度图像）
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AdaptiveAvgPool2d((1, 1))  # 全局平均池化
        )
        # 全连接层
        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128)  # 输出特征维度为128
        )

    def forward(self, x):
        # 提取特征
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        # L2归一化
        x = F.normalize(x, p=2, dim=1)
        return x

In [13]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, feature1, feature2, label):
        distance = F.pairwise_distance(feature1, feature2)
        loss = torch.mean((1 - label) * torch.pow(distance, 2) +  # 正样本对
                          label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))  # 负样本对
        return loss

class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature1, feature2,label):
        similarity_matrix = torch.matmul(feature1, feature2.T) / self.temperature
        pos_similarity = torch.diag(similarity_matrix)
        neg_similarity = torch.logsumexp(similarity_matrix, dim=1) - pos_similarity
        loss = -torch.mean((1 - label) * pos_similarity - label * neg_similarity)
        return loss

In [14]:
import torch.optim as optim
from torch.utils.data import DataLoader

model = ContrastiveNet()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, train_loader, criterion, optimizer, num_epochs=150):
    model.train()
    best_loss = float('inf')
    best_model_state = None
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in train_loader:
            img1, img2, label = batch
            feature1 = model(img1)
            feature2 = model(img2)
            loss = criterion(feature1, feature2,label)  # 如果使用 ContrastiveLoss，需要传入 label
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_epoch_loss:.4f}")
        
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            best_model_state = model.state_dict()  # 保存最佳模型状态
        else:
            best_model_state = None

        
        if best_model_state is not None:
            torch.save(best_model_state, "best_contrastive_model.pth")
            print(f"Saved best model with loss: {best_loss:.4f}")
# 开始训练
train(model, train_loader, criterion, optimizer, num_epochs=100)


Epoch [1/100], Loss: 0.4663
Saved best model with loss: 0.4663
Epoch [2/100], Loss: 0.4363
Saved best model with loss: 0.4363
Epoch [3/100], Loss: 0.4344
Saved best model with loss: 0.4344
Epoch [4/100], Loss: 0.3069
Saved best model with loss: 0.3069
Epoch [5/100], Loss: 0.2456
Saved best model with loss: 0.2456
Epoch [6/100], Loss: 0.2477
Epoch [7/100], Loss: 0.1894
Saved best model with loss: 0.1894
Epoch [8/100], Loss: 0.1583
Saved best model with loss: 0.1583
Epoch [9/100], Loss: 0.1888
Epoch [10/100], Loss: 0.1783
Epoch [11/100], Loss: 0.1649
Epoch [12/100], Loss: 0.1777
Epoch [13/100], Loss: 0.1974
Epoch [14/100], Loss: 0.1908
Epoch [15/100], Loss: 0.1479
Saved best model with loss: 0.1479
Epoch [16/100], Loss: 0.1584
Epoch [17/100], Loss: 0.1817
Epoch [18/100], Loss: 0.1624
Epoch [19/100], Loss: 0.1513
Epoch [20/100], Loss: 0.1739
Epoch [21/100], Loss: 0.1672
Epoch [22/100], Loss: 0.1441
Saved best model with loss: 0.1441
Epoch [23/100], Loss: 0.1400
Saved best model with loss:

In [15]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset
import cv2

class CustomDataset(Dataset):
    def __init__(self, dataset_path, json_path, transform=None, augment_transform=None, n=2):
        self.dataset_path = dataset_path
        self.transform = transform
        self.augment_transform = augment_transform
        self.n = n

        with open(json_path, 'r') as f:
            self.labels = json.load(f)

        self.image_files = [f for f in os.listdir(dataset_path) if f.endswith('.png')]
    
    def __len__(self):
        return len(self.image_files) * self.n
    
    def _extract_prefix(self, filename):
        parts = filename.split('_')
        if len(parts) >= 3:
            prefix = '_'.join(parts[:3])
            view_char = parts[2]
            return prefix,view_char
        return filename  # 如果不足三个 '_'，返回原文件名
    
    def __getitem__(self, idx):
        original_idx = idx % len(self.image_files)

        img_name = self.image_files[original_idx]

        img_path = os.path.join(self.dataset_path, img_name)

        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # 读取灰度图像

        img_prefix,view_char = self._extract_prefix(img_name)
        if view_char.lower().startswith('r'):
            image = cv2.flip(image, 1)  # 水平翻转图像

        normalized_image = image.astype(np.float32) / 255.0
        normalized_image = (normalized_image * 255).astype(np.uint8)    
        rgb_image = cv2.applyColorMap(normalized_image, cv2.COLORMAP_JET)

        rgb_image = Image.fromarray(cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB))

        label = -1  # 默认标签
        for key in self.labels:
            key_prefix,_ = self._extract_prefix(key)
            if key_prefix == img_prefix:
                label = self.labels[key]
                break
        if idx >= len(self.image_files) and self.augment_transform:
            # 对数据增强的样本应用增强变换
            rgb_image = self.augment_transform(rgb_image)
        elif self.transform:
            # 对原始样本应用默认变换
            rgb_image = self.transform(rgb_image)
        single_channel_image = rgb_image[1].unsqueeze(0)
        return single_channel_image, label

In [16]:
class BinaryClassificationNet(nn.Module):
    def __init__(self, feature_extractor, num_classes=1, hidden_dim=64):
        super(BinaryClassificationNet, self).__init__()
        self.feature_extractor = feature_extractor  # 对比学习的特征提取器
        self.classifier = nn.Sequential(
            nn.Linear(128, hidden_dim),  # 假设特征维度是128
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes)  # 二分类任务
        )
        # 冻结特征提取器的参数
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return logits  # 使用BCEWithLogitsLoss时不需要Sigmoid


In [17]:
contrastive_model = ContrastiveNet()
contrastive_model.load_state_dict(torch.load("best_contrastive_model.pth"))
feature_extractor = contrastive_model.feature_extractor

  contrastive_model.load_state_dict(torch.load("best_contrastive_model.pth"))


In [18]:
binary_model = BinaryClassificationNet(feature_extractor)

In [19]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(binary_model.parameters(), lr=0.001)

In [20]:
dataset_path = r"F:\Dataset\Test\test_phase2\rgb_all_png\test_extend_DBT_slice_rgb_patch3"
json_path = r"F:\Dataset\Test\results_test.json"


default_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图片大小
    transforms.ToTensor(),          # 转换为Tensor
])

augment_transform = transforms.Compose([
    transforms.Resize((224, 224)),            # 调整图片大小
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomRotation(10),           # 随机旋转
    #transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 随机调整亮度和对比度
    transforms.ToTensor(),                   # 转换为Tensor
])

dataset = CustomDataset(
    dataset_path=dataset_path,
    json_path=json_path,
    transform=default_transform,
    augment_transform=augment_transform,
    n=2
)

train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

val_dataset  = r"F:\Dataset\Val\val_phase2\rgb_all_png\val_extend_DBT_slice_rgb_patch3"
val_json = r"F:\Dataset\Val\results_val.json"

val_dataset =CustomDataset(
    dataset_path=val_dataset,
    json_path=val_json,
    transform=default_transform,
    augment_transform=augment_transform,
    n=2
)

val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

In [21]:
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, average_precision_score, matthews_corrcoef,
                             cohen_kappa_score, confusion_matrix, classification_report)
import numpy as np

def train_model(train_dataloader, val_dataloader, model, criterion, optimizer, num_epochs=10):
    best_auc = 0.0  # 用于保存最高的 AUC 值
    log_file = "training_log.txt"  # 日志文件路径

    # 打开日志文件
    with open(log_file, "w") as f:
        f.write("Epoch\tTrain Loss\tTrain Accuracy\tTrain AUC\tTrain Precision\tTrain Sensitivity\tTrain Specificity\tTrain f1\tVal Accuracy\tVal AUC\tVal Precision\tVal Sensitivity\tVal Specificity\tVal f1\n")

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_preds = []
        train_labels = []

        for images, labels in train_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            labels = labels.float()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_preds.extend(outputs.squeeze().detach().cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        train_loss = running_loss / len(train_dataloader)
        train_accuracy = accuracy_score(train_labels, (np.array(train_preds) > 0.5).astype(int))
        train_auc = roc_auc_score(train_labels, train_preds)
        train_precision = precision_score(train_labels, (np.array(train_preds) > 0.5).astype(int))
        train_sensitivity = recall_score(train_labels, (np.array(train_preds) > 0.5).astype(int))
        train_specificity = recall_score(train_labels, (np.array(train_preds) > 0.5).astype(int), pos_label=0)
        train_f1score = f1_score(train_labels, (np.array(train_preds) > 0.5).astype(int))

        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for images, labels in val_dataloader:
                images = images.to(device)
                labels = labels.to(device)
                labels = labels.float()
                outputs = model(images)
                val_preds.extend(outputs.squeeze().cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_accuracy = accuracy_score(val_labels, (np.array(val_preds) > 0.5).astype(int))
        val_auc = roc_auc_score(val_labels, val_preds)
        val_precision = precision_score(val_labels, (np.array(val_preds) > 0.5).astype(int))
        val_sensitivity = recall_score(val_labels, (np.array(val_preds) > 0.5).astype(int))
        val_specificity = recall_score(val_labels, (np.array(val_preds) > 0.5).astype(int), pos_label=0)
        val_f1score = f1_score(val_labels, (np.array(val_preds) > 0.5).astype(int))

        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.4f}, "
              f"Train AUC: {train_auc:.4f}, "
              f"Train Precision: {train_precision:.4f}, "
              f"Val Accuracy: {val_accuracy:.4f}, "
              f"Val AUC: {val_auc:.4f}, "
              f"Val Precision: {val_precision:.4f}")

        # 保存指标到日志文件
        with open(log_file, "a") as f:
            f.write(f"{epoch + 1}\t{train_loss:.4f}\t{train_accuracy:.4f}\t{train_auc:.4f}\t{train_precision:.4f}\t{train_sensitivity:.4f}\t{train_specificity:.4f}\t{train_f1score:.4f}\t"
                    f"{val_accuracy:.4f}\t{val_auc:.4f}\t{val_precision:.4f}\t{val_sensitivity:.4f}\t{val_specificity:.4f}\t{val_f1score:.4f}\n")

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"New best model saved with AUC: {best_auc:.4f}")

model = binary_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BinaryClassificationNet(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (classifier): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [22]:
train_model(train_dataloader, val_dataloader, model, criterion, optimizer, num_epochs=50)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [1/50], Train Loss: 0.6929, Train Accuracy: 0.4762, Train AUC: 0.4154, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6662, Val Precision: 0.0000
New best model saved with AUC: 0.6662


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [2/50], Train Loss: 0.6921, Train Accuracy: 0.4762, Train AUC: 0.4967, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6502, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [3/50], Train Loss: 0.6920, Train Accuracy: 0.4762, Train AUC: 0.4937, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6962, Val Precision: 0.0000
New best model saved with AUC: 0.6962


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [4/50], Train Loss: 0.6913, Train Accuracy: 0.4762, Train AUC: 0.4954, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6857, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [5/50], Train Loss: 0.6909, Train Accuracy: 0.4762, Train AUC: 0.5156, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6655, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [6/50], Train Loss: 0.6907, Train Accuracy: 0.4762, Train AUC: 0.5113, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6674, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [7/50], Train Loss: 0.6903, Train Accuracy: 0.4762, Train AUC: 0.5049, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6676, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [8/50], Train Loss: 0.6894, Train Accuracy: 0.4762, Train AUC: 0.5283, Train Precision: 0.0000, Val Accuracy: 0.5072, Val AUC: 0.6513, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch [9/50], Train Loss: 0.6893, Train Accuracy: 0.4802, Train AUC: 0.5402, Train Precision: 1.0000, Val Accuracy: 0.5072, Val AUC: 0.6595, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch [10/50], Train Loss: 0.6887, Train Accuracy: 0.4921, Train AUC: 0.5278, Train Precision: 1.0000, Val Accuracy: 0.5072, Val AUC: 0.6601, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch [11/50], Train Loss: 0.6891, Train Accuracy: 0.4921, Train AUC: 0.5128, Train Precision: 1.0000, Val Accuracy: 0.5072, Val AUC: 0.6765, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch [12/50], Train Loss: 0.6884, Train Accuracy: 0.4881, Train AUC: 0.5215, Train Precision: 0.8000, Val Accuracy: 0.5072, Val AUC: 0.6571, Val Precision: 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch [13/50], Train Loss: 0.6866, Train Accuracy: 0.4881, Train AUC: 0.5514, Train Precision: 0.8000, Val Accuracy: 0.5072, Val AUC: 0.6620, Val Precision: 0.0000
Epoch [14/50], Train Loss: 0.6871, Train Accuracy: 0.4881, Train AUC: 0.5302, Train Precision: 0.7143, Val Accuracy: 0.5217, Val AUC: 0.6559, Val Precision: 1.0000
Epoch [15/50], Train Loss: 0.6879, Train Accuracy: 0.4881, Train AUC: 0.5450, Train Precision: 0.6667, Val Accuracy: 0.5217, Val AUC: 0.6700, Val Precision: 1.0000
Epoch [16/50], Train Loss: 0.6859, Train Accuracy: 0.4841, Train AUC: 0.5665, Train Precision: 0.6667, Val Accuracy: 0.5217, Val AUC: 0.6622, Val Precision: 1.0000
Epoch [17/50], Train Loss: 0.6854, Train Accuracy: 0.4960, Train AUC: 0.5588, Train Precision: 0.7778, Val Accuracy: 0.5217, Val AUC: 0.6546, Val Precision: 1.0000
Epoch [18/50], Train Loss: 0.6850, Train Accuracy: 0.4921, Train AUC: 0.5683, Train Precision: 0.7500, Val Accuracy: 0.5217, Val AUC: 0.6712, Val Precision: 1.0000
Epoch [19/50], T