In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pathlib import Path
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

  Referenced from: <08E12B12-6183-307E-BDA0-374FA8EBA2C9> /Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/torchvision/image.so
  warn(
  _torch_pytree._register_pytree_node(


In [7]:
import torch.nn.functional as F

class PAN(nn.Module):
    """ResNet-50 encoder followed by a small convolutional segmentation head.

    The encoder consists of every child module of a torchvision ``resnet50``
    except the last two. The head reduces the 2048-channel encoder output to
    ``num_classes`` channels of raw logits (no activation is applied), and
    ``forward`` bilinearly upsamples those logits to the spatial size of the
    input batch.

    Args:
        num_classes: Number of output channels produced by the head.
        pretrained_backbone: When true *and* ``backbone_weights_path`` is
            given, the backbone is initialised from that checkpoint.
        backbone_weights_path: Path to a ResNet-50 ``state_dict`` file. When it
            is ``None`` or empty the backbone keeps its random initialisation,
            even if ``pretrained_backbone`` is true.
    """

    def __init__(self, num_classes=1, pretrained_backbone=True, backbone_weights_path=None):
        super(PAN, self).__init__()
        # Imported locally: no visible cell of this notebook imports resnet50,
        # so relying on a global name would raise NameError on a fresh kernel.
        from torchvision.models import resnet50

        # Build the architecture only; weights (if any) come from the local file
        # below, so nothing is downloaded here.
        self.backbone = resnet50()
        if pretrained_backbone and backbone_weights_path:
            # map_location="cpu" lets a checkpoint saved from a GPU load on a
            # CPU-only machine; the caller moves the model to its device later.
            # NOTE(review): torch.load unpickles the file - only load
            # checkpoints from a trusted source.
            state_dict = torch.load(backbone_weights_path, map_location="cpu")
            # strict=False tolerates key mismatches between checkpoint and model.
            self.backbone.load_state_dict(state_dict, strict=False)

        # Encoder layers: all backbone children except the final two.
        # The modules are shared with self.backbone, not copied.
        self.encoder = nn.Sequential(*list(self.backbone.children())[:-2])

        # Head producing per-pixel logits at the encoder's (reduced) resolution
        self.head = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Return logits with ``num_classes`` channels at the input's height and width."""
        # Remember the input resolution so the output matches it for any image
        # size, rather than assuming 256x256 inputs.
        input_size = tuple(x.shape[-2:])
        x = self.encoder(x)
        x = self.head(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)  # Upsample to input size
        return x

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pathlib import Path
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

# Custom dataset of image / segmentation-mask pairs
class ContrailDataset(Dataset):
    """Pairs every image in ``image_dir`` with the same-named file in ``mask_dir``.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        mask_dir: Directory holding one mask per image, under the same file name.
        transform: Optional callable applied to the RGB image only.

    Each item is ``(image, mask)``:
        * ``image`` is ``transform(image)`` when a transform is set, otherwise
          the RGB ``PIL.Image`` as loaded.
        * ``mask`` is a float tensor of 0.0 / 1.0 values (1.0 wherever the
          grayscale mask is non-zero), returned as ``ToTensor()(mask).unsqueeze(0)``.
          NOTE(review): that is one more leading axis than the image tensor
          has; it is kept because existing consumers may rely on it - confirm
          before changing.
    """

    # Extensions accepted as images; compared case-insensitively.
    IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, mask_dir, transform=None):
        super().__init__()
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        # iterdir() yields entries in arbitrary order; sorting makes indices
        # (and therefore any index-based split) stable between runs.
        self.files = sorted(
            f for f in self.image_dir.iterdir()
            if f.is_file() and f.suffix.lower() in self.IMAGE_SUFFIXES
        )
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image_path = self.files[idx]
        mask_path = self.mask_dir / image_path.name

        # convert() returns a fully loaded copy, so the files can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')  # read as an RGB image
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')  # read as a grayscale image

        if self.transform:
            image = self.transform(image)

        # Keep the mask aligned with the image when the transform changed the
        # image size. Nearest-neighbour resampling avoids inventing
        # intermediate gray values.
        if isinstance(image, torch.Tensor) and image.dim() >= 2:
            height, width = image.shape[-2:]
            target_size = (width, height)
        elif isinstance(image, Image.Image):
            target_size = image.size
        else:
            target_size = None  # unknown image type: leave the mask untouched
        if target_size is not None and mask.size != target_size:
            mask = mask.resize(target_size, Image.NEAREST)

        # Always convert the mask to a tensor: the comparison below is not
        # defined for PIL images, so the transform=None path needs it as well.
        mask = transforms.ToTensor()(mask)

        mask = (mask > 0).float()  # binarize the mask
        return image, mask.unsqueeze(0)  # prepend one extra leading axis

# Dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``forward`` applies a sigmoid to ``inputs`` (so it expects raw logits),
    flattens both tensors and returns

        1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    Args (forward):
        inputs: Logits tensor.
        targets: Target tensor with the same number of elements as ``inputs``.
        smooth: Additive constant in numerator and denominator; keeps the
            ratio defined when both tensors sum to zero.

    Returns:
        A scalar tensor.
    """

    def forward(self, inputs, targets, smooth=1):
        # reshape(-1) instead of view(-1): view raises on non-contiguous
        # tensors (e.g. after a transpose or permute), reshape copies if needed.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)
        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

# Precision, recall, F1 score and IoU over a whole batch
def calculate_metrics(preds, labels):
    """Score thresholded predictions against labels, pooling every element of the batch.

    Args:
        preds: Logits tensor; a sigmoid is applied and values above 0.5 count
            as positive.
        labels: Ground-truth tensor with the same number of elements.

    Returns:
        ``(precision, recall, f1, iou)`` as computed by the scikit-learn
        scorers, each called with ``zero_division=1``.
    """
    y_pred = torch.sigmoid(preds).gt(0.5).cpu().numpy().flatten()  # binarized predictions
    y_true = labels.cpu().numpy().flatten()

    # Same order as the returned tuple.
    scorers = (precision_score, recall_score, f1_score, jaccard_score)
    return tuple(scorer(y_true, y_pred, zero_division=1) for scorer in scorers)

# Mean IoU (mIoU): the IoU is computed per sample and then averaged over the batch
def calculate_mIoU(preds, labels):
    """Average the per-sample IoU over the first (batch) axis.

    Args:
        preds: Logits tensor; a sigmoid is applied and values above 0.5 count
            as positive.
        labels: Ground-truth tensor indexed along the same first axis.

    Returns:
        The arithmetic mean of ``jaccard_score`` (called with
        ``zero_division=1``) evaluated separately on each sample.
    """
    binary_preds = torch.sigmoid(preds).gt(0.5).cpu().numpy()
    truth = labels.cpu().numpy()

    # One score per sample in the batch.
    per_sample = [
        jaccard_score(truth[k].flatten(), binary_preds[k].flatten(), zero_division=1)
        for k in range(binary_preds.shape[0])
    ]
    return sum(per_sample) / len(per_sample)

def calculate_accuracy(preds, labels):
    """Return the fraction of elements whose thresholded prediction equals its label.

    Args:
        preds: Logits tensor; a sigmoid is applied and values above 0.5 become
            1.0, everything else 0.0.
        labels: Ground-truth tensor with the same number of elements.

    Returns:
        A Python float: matching elements divided by the number of label elements.
    """
    # Binarize, then flatten both sides so they are compared element by element.
    hard_preds = (torch.sigmoid(preds) > 0.5).float().view(-1)
    flat_labels = labels.view(-1)

    n_correct = (hard_preds == flat_labels).sum().item()
    n_total = flat_labels.size(0)
    return n_correct / n_total

# Training loop
def train_model(model, train_loader, valid_loader, loss_fn, optimizer, num_epochs, device,
                save_path='PAN_model.pth'):
    """Train ``model`` and checkpoint it whenever the validation F1 improves.

    Every epoch runs one pass over ``train_loader`` (with optimisation) and one
    pass over ``valid_loader`` (under ``torch.no_grad()``), then prints a
    summary line. All reported numbers are unweighted means of per-batch
    values, so a smaller final batch counts as much as a full one.

    Args:
        model: Module to optimise; called as ``model(images)``.
        train_loader: Sized iterable of ``(images, masks)`` training batches.
        valid_loader: Sized iterable of ``(images, masks)`` validation batches.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        optimizer: Optimiser already bound to the model's parameters.
        num_epochs: Number of epochs to run.
        device: Device each batch is moved to.
        save_path: Where ``model.state_dict()`` is written each time the mean
            validation F1 exceeds the best value seen so far.

    Raises:
        ValueError: If either loader yields no batches; the per-epoch averages
            would otherwise fail with a ZeroDivisionError after a full epoch.
    """
    n_train_batches = len(train_loader)
    n_valid_batches = len(valid_loader)
    if n_train_batches == 0 or n_valid_batches == 0:
        raise ValueError(
            f'train_loader and valid_loader must each yield at least one batch '
            f'(got {n_train_batches} and {n_valid_batches})'
        )

    best_f1 = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_acc = 0

        # Training phase
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Accuracy is report-only: detach so it does not extend the autograd graph.
            train_acc += calculate_accuracy(outputs.detach(), masks)

        # Validation phase
        model.eval()
        valid_loss = 0
        valid_acc = 0
        valid_precision = 0
        valid_recall = 0
        valid_f1 = 0
        valid_iou = 0
        valid_miou = 0

        with torch.no_grad():
            for images, masks in valid_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)

                valid_loss += loss_fn(outputs, masks).item()
                valid_acc += calculate_accuracy(outputs, masks)

                precision, recall, f1, iou = calculate_metrics(outputs, masks)
                valid_precision += precision
                valid_recall += recall
                valid_f1 += f1
                valid_iou += iou
                valid_miou += calculate_mIoU(outputs, masks)

        # Per-batch means
        avg_train_loss = train_loss / n_train_batches
        avg_train_acc = train_acc / n_train_batches
        avg_valid_loss = valid_loss / n_valid_batches
        avg_valid_acc = valid_acc / n_valid_batches
        avg_valid_precision = valid_precision / n_valid_batches
        avg_valid_recall = valid_recall / n_valid_batches
        avg_valid_f1 = valid_f1 / n_valid_batches
        avg_valid_iou = valid_iou / n_valid_batches
        avg_valid_miou = valid_miou / n_valid_batches

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, '
              f'Val Loss: {avg_valid_loss:.4f}, Val Acc: {avg_valid_acc:.4f}, '
              f'Precision: {avg_valid_precision:.4f}, Recall: {avg_valid_recall:.4f}, '
              f'F1 Score: {avg_valid_f1:.4f}, IoU: {avg_valid_iou:.4f}, mIoU: {avg_valid_miou:.4f}')

        # Keep the best model so far (by mean validation F1)
        if avg_valid_f1 > best_f1:
            best_f1 = avg_valid_f1
            torch.save(model.state_dict(), save_path)

# Configuration
SEED = 42  # fixes the train/validation split and the weight initialisation between runs
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_dir = '/Users/camus/Desktop/trainX/image'
mask_dir = '/Users/camus/Desktop/trainX/maskimage'

batch_size = 4
num_epochs = 100

# Image preprocessing: resize straight to 256x256
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # resize the image to 256x256
    transforms.ToTensor(),
])

# Load the dataset and split it 75% / 25% into training and validation subsets
dataset = ContrailDataset(image_dir, mask_dir, transform=transform)
train_size = int(0.75 * len(dataset))
valid_size = len(dataset) - train_size
# A dedicated, seeded generator keeps the split identical from run to run, so
# validation metrics of different runs refer to the same held-out images.
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

# Path to the backbone weight file
backbone_weights_path = '/Volumes/Vettel/学习/预训练模型/resnet50-0676ba61.pth'

model = PAN(num_classes=1, pretrained_backbone=True, backbone_weights_path=backbone_weights_path).to(device)
loss_fn = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(model, train_loader, valid_loader, loss_fn, optimizer, num_epochs, device)



Epoch [1/100], Train Loss: 0.8135, Train Acc: 0.1064, Val Loss: 0.8221, Val Acc: 0.0984, Precision: 0.0984, Recall: 1.0000, F1 Score: 0.1779, IoU: 0.0984, mIoU: 0.0984
Epoch [2/100], Train Loss: 0.8179, Train Acc: 0.1017, Val Loss: 0.8221, Val Acc: 0.0984, Precision: 0.0984, Recall: 1.0000, F1 Score: 0.1779, IoU: 0.0984, mIoU: 0.0984
Epoch [3/100], Train Loss: 0.7314, Train Acc: 0.4189, Val Loss: 0.7912, Val Acc: 0.4656, Precision: 0.1230, Recall: 0.7296, F1 Score: 0.2088, IoU: 0.1179, mIoU: 0.1233
Epoch [4/100], Train Loss: 0.3745, Train Acc: 0.9154, Val Loss: 0.3327, Val Acc: 0.9344, Precision: 0.6580, Recall: 0.7139, F1 Score: 0.6675, IoU: 0.5086, mIoU: 0.5049
Epoch [5/100], Train Loss: 0.3107, Train Acc: 0.9344, Val Loss: 0.2908, Val Acc: 0.9359, Precision: 0.6430, Recall: 0.8269, F1 Score: 0.7093, IoU: 0.5572, mIoU: 0.5590
Epoch [6/100], Train Loss: 0.2931, Train Acc: 0.9382, Val Loss: 0.2574, Val Acc: 0.9457, Precision: 0.6987, Recall: 0.8261, F1 Score: 0.7426, IoU: 0.5993, mIoU:

Epoch [50/100], Train Loss: 0.1257, Train Acc: 0.9750, Val Loss: 0.2494, Val Acc: 0.9522, Precision: 0.7597, Recall: 0.7859, F1 Score: 0.7507, IoU: 0.6128, mIoU: 0.6291
Epoch [51/100], Train Loss: 0.1340, Train Acc: 0.9737, Val Loss: 0.2462, Val Acc: 0.9525, Precision: 0.7513, Recall: 0.7929, F1 Score: 0.7538, IoU: 0.6121, mIoU: 0.6149
Epoch [52/100], Train Loss: 0.1259, Train Acc: 0.9750, Val Loss: 0.2479, Val Acc: 0.9535, Precision: 0.7816, Recall: 0.7697, F1 Score: 0.7522, IoU: 0.6183, mIoU: 0.6349
Epoch [53/100], Train Loss: 0.1507, Train Acc: 0.9694, Val Loss: 0.4438, Val Acc: 0.9332, Precision: 0.8345, Recall: 0.4394, F1 Score: 0.5562, IoU: 0.3920, mIoU: 0.4317
Epoch [54/100], Train Loss: 0.2037, Train Acc: 0.9591, Val Loss: 0.2374, Val Acc: 0.9506, Precision: 0.7203, Recall: 0.8436, F1 Score: 0.7626, IoU: 0.6228, mIoU: 0.6384
Epoch [55/100], Train Loss: 0.1589, Train Acc: 0.9673, Val Loss: 0.2302, Val Acc: 0.9529, Precision: 0.7352, Recall: 0.8429, F1 Score: 0.7698, IoU: 0.6314,

Epoch [99/100], Train Loss: 0.1045, Train Acc: 0.9795, Val Loss: 0.2420, Val Acc: 0.9545, Precision: 0.7842, Recall: 0.7700, F1 Score: 0.7580, IoU: 0.6202, mIoU: 0.6363
Epoch [100/100], Train Loss: 0.0997, Train Acc: 0.9802, Val Loss: 0.2306, Val Acc: 0.9565, Precision: 0.7870, Recall: 0.7886, F1 Score: 0.7695, IoU: 0.6361, mIoU: 0.6499
