In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_curve, roc_auc_score

# get training class weight, aff : nff = 2:8
def get_classweight(train_dataset):
    train_nff_count = sum(1 for _, label in train_dataset if label == 0)  # 0 NFF
    train_aff_count = sum(1 for _, label in train_dataset if label == 1)  # 1 AFF
    class_weight_nff = 1 / (2 * (train_nff_count / (train_nff_count + train_aff_count)))
    class_weight_aff = 1 / (2 * (train_aff_count / (train_nff_count + train_aff_count)))
    
    return [class_weight_nff, class_weight_aff] # 0 nff , 1 aff 

# training function
def train_model(train_loader, validation_loader, classweight, num_epochs, lr, step_size, gamma, model_name):
    # Load pre-trained model
    torch.hub.set_dir('/local/data1/honzh073/download/TORCH_PRETRAINED')
    
    if model_name == 'resnet18':
        from torchvision.models import resnet18, ResNet18_Weights
        model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    
    elif model_name == 'resnet50':
        from torchvision.models import resnet50, ResNet50_Weights
        model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        
    elif model_name == 'resnet101':
        from torchvision.models import resnet101, ResNet101_Weights
        model = models.resnet101(weights=ResNet101_Weights.DEFAULT)
         
    elif model_name == 'resnet152':
        from torchvision.models import resnet152, ResNet152_Weights
        model = models.resnet152(weights=ResNet152_Weights.DEFAULT)
        
    elif model_name == 'densenet161':
        from torchvision.models import densenet161, DenseNet161_Weights
        model = models.densenet161(weights=DenseNet161_Weights.DEFAULT)

    elif model_name == 'vgg19':
        from torchvision.models import vgg19, VGG19_Weights
        model = models.vgg19(weights=VGG19_Weights.DEFAULT)
        
    else:
        raise ValueError("Invalid model type. 'vgg19' 'resnet50' 'resnet101' 'resnet152' or 'densenet161'.")
    
    # freeze all layers except fc
    # for param in model.parameters():
    #     param.requires_grad = False
    
    # fc
    num_class = 2
    
    if model_name == 'densenet161':
        in_features = model.classifier.in_features
        model.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(in_features, num_class)) # dropout
        # model.classifier = nn.Sequential(nn.Linear(in_features, num_class)) # no dropout
    if model_name == 'vgg19':
        in_features = model.classifier[0].in_features
        model.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(in_features, num_class)) # dropout
        # model.classifier = nn.Sequential(nn.Linear(in_features, num_class)) # no dropout
    # fc1 = vgg_model.classifier[0]

    else:
        in_features = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(in_features, num_class)) # dropout
        # model.fc = nn.Sequential(nn.Linear(in_features, num_class)) # no dropout
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # # DataParallel speed up
    # if torch.cuda.device_count() > 1:
    #     print("multiple GPU:", torch.cuda.device_count())
    #     model = nn.DataParallel(model)
    # else:
    #     print("single GPU")
    #     model = model.to(device)
    
    # loss function and learning rate
    criterion = nn.CrossEntropyLoss(weight=torch.Tensor(classweight).to(device))
    # optimizer = optim.Adam(model.parameters(), lr=lr)
    
    optimizer = optim.RMSprop(model.parameters(), lr=lr)

    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    # # scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

        
    # Loss, ACC
    train_losses = []
    validation_losses = []
    train_accuracies = []
    validation_accuracies = []
    
    # select best model
    best_validation_accuracy = 0
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        correct_train = 0
        total_train = 0
        
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
        
        train_accuracy = 100 * correct_train / total_train
        
        model.eval()
        correct_validation_aff = 0
        total_validation_aff = 0
        correct_validation = 0
        total_validation = 0
        validation_loss = 0
        
        with torch.no_grad():
            for images, labels in validation_loader:
                images = images.to(device)
                labels = labels.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                _, predicted = torch.max(outputs.data, 1)
                total_validation += labels.size(0)
                correct_validation += (predicted == labels).sum().item()
                validation_loss += loss.item()
                # Calculate accuracy for AFF class
                total_validation_aff += torch.sum(labels == 1).item()
                correct_validation_aff += torch.sum((predicted == 1) & (labels == 1)).item()

        validation_accuracy_aff = 100 * correct_validation_aff / total_validation_aff

        # validation accuracy and loss
        validation_accuracy = 100 * correct_validation / total_validation
        validation_loss /= len(validation_loader)
        
        train_losses.append(loss.item())
        validation_losses.append(validation_loss)
        train_accuracies.append(train_accuracy)
        validation_accuracies.append(validation_accuracy)
        
        print(f"Epoch {epoch+1}/{num_epochs}, "
            f"train Loss: {loss.item():.4f}, "
            f"val Loss: {validation_loss:.4f}, "
            f"train ACC: {train_accuracy:.2f}%, "
            f"Val ACC: {validation_accuracy:.2f}%")
        
        # scheduler.step()
        scheduler.step()

        # select by high aff ACC
        # if validation_accuracy_aff > best_accuracy:
        #     best_accuracy = validation_accuracy_aff
        #     best_model = model.module if isinstance(model, nn.DataParallel) else model

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            # best_model = model.module  # DataParallel
            best_model = model  # single device
            
    # Plot train/val loss,  accuracy
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(validation_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(validation_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend() 

    plt.show()

    return best_model

def test_model(model, test_dataset, batch_size):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
    model.eval()
    
    correct_test = 0
    total_test = 0
    test_loss = 0
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        all_predictions = []
        all_labels = []
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()
            test_loss += loss.item()

    test_accuracy = 100 * correct_test / total_test
    test_loss /= len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")
     
    auc_score = roc_auc_score(all_labels, all_predictions)
    conf_matrix = confusion_matrix(all_labels, all_predictions)
    
    # Precision、Recall、F1 Score
    class_labels = {0: 'NFF', 1: 'AFF'}

    classification_rep = classification_report(all_labels, all_predictions, target_names=[class_labels[i] for i in range(len(class_labels))])
    print("AUC:", auc_score)
    print("Confusion Matrix:")
    print(conf_matrix)
    print("Classification Report:")
    print(classification_rep)

    plot_roc_curve(all_labels, all_predictions)
    
def plot_roc_curve(all_labels, all_predictions):
    fpr, tpr, thresholds = roc_curve(all_labels, all_predictions)
    roc_auc = roc_auc_score(all_labels, all_predictions)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.show()

def show_image(dataset, num_images=5):
    # Get random indices
    random_indices = np.random.choice(len(dataset), num_images, replace=False)

    # Plot images with truncated names
    plt.figure(figsize=(15, 5))
    for i, idx in enumerate(random_indices):
        image, label = dataset[idx]  # Use the dataset directly
        filename = dataset.data[idx][0]  # Get the filename from dataset's internal data attribute
        truncated_filename = filename.split('/')[-1][:15]  # Extract the last part and truncate to 15 characters
        
        # Print the original filename
        print(f"Image location: {filename}")

        plt.subplot(1, num_images, i + 1)
        plt.title(f"Label: {label}")
        plt.imshow(image[0])  # Assuming single-channel (grayscale) image
        # plt.imshow(image[0], cmap='gray')  # Assuming single-channel (grayscale) image

        plt.axis('off')
    plt.show()
    
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)



In [None]:
# import csv
# import os
# from collections import defaultdict

# def filter_hospital_data(input_csv_path, target_hospital_ids):
#     patient_data = defaultdict(list)

#     with open(input_csv_path, 'r') as csvfile:
#         reader = csv.DictReader(csvfile)
#         for row in reader:
#             if row['HospitalID'] in target_hospital_ids:
#                 patient_id = row['PatientID']
#                 patient_data[patient_id].append(row)

#     hospital_55_data = []
#     for images in patient_data.values():
#         hospital_55_data.extend(images)

#     return hospital_55_data

# # Input and output paths
# input_csv_path = "/local/data1/honzh073/local_repository/FL/code/3_single_hospital/csv_files/image_data.csv"
# output_folder = '/local/data1/honzh073/local_repository/FL/code/3_single_hospital/csv_files/'

# # Single hospital id
# target_hospital_ids = ['55']

# # Get data for hospital 55
# hospital_55_data = filter_hospital_data(input_csv_path, target_hospital_ids)

# # Write hospital 55 data to CSV file
# def write_to_csv(file_path, data):
#     with open(file_path, 'w', newline='') as csvfile:
#         writer = csv.DictWriter(csvfile, fieldnames=data[0].keys())
#         writer.writeheader()
#         writer.writerows(data)

# # Save hospital 55 data to 'hospital55.csv'
# write_to_csv(os.path.join(output_folder, 'hospital55.csv'), hospital_55_data)

# print("Saved hospital55.csv for hospital 55.")


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = item['image_path']
        label = item['label']

        # Load image
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Convert label to numerical representation
        if label == 'NFF':
            label = 0
        elif label == 'AFF':
            label = 1
        else:
            raise ValueError("Invalid label in JSON data.")

        return image, label

# Define data augmentation transforms for training data
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=30),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomAdjustSharpness(sharpness_factor=5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load JSON data
with open('/local/data1/honzh073/data/suzuki.json', 'r') as json_file:
    json_data = json.load(json_file)
    train_data = json_data['train']
    val_data = json_data['val']

# Create datasets and data loaders
train_dataset = CustomDataset(train_data, transform=train_transform)
val_dataset = CustomDataset(val_data, transform=test_transform)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Count 'AFF' and 'NFF' labels in train dataset
train_aff_count = sum(1 for _, label in train_dataset if label == 1)  # 'AFF' label is 1
train_nff_count = sum(1 for _, label in train_dataset if label == 0)  # 'NFF' label is 0

# Count 'AFF' and 'NFF' labels in validation dataset
val_aff_count = sum(1 for _, label in val_dataset if label == 1)  # 'AFF' label is 1
val_nff_count = sum(1 for _, label in val_dataset if label == 0)  # 'NFF' label is 0

print(f"Train dataset: AFF count: {train_aff_count}, NFF count: {train_nff_count}")
print(f"Validation dataset: AFF count: {val_aff_count}, NFF count: {val_nff_count}")
classweight = get_classweight(train_dataset)
print(classweight)

Train dataset: AFF count: 49, NFF count: 123
Validation dataset: AFF count: 26, NFF count: 47
[0.6991869918699186, 1.7551020408163265]


In [4]:
epoch_num = 50
step_size = 10
lr = 1e-4

In [8]:
# resnet101
resnet101 = train_model(train_loader, val_loader, classweight, 
                        num_epochs=epoch_num, lr=lr, step_size=step_size, gamma=0.1, model_name='resnet101')


KeyboardInterrupt: 

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import json
import nvflare
from nvflare.apis import fl

# 定义你的数据集类
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = item['image_path']
        label = item['label']

        # Load image
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Convert label to numerical representation
        if label == 'NFF':
            label = 0
        elif label == 'AFF':
            label = 1
        else:
            raise ValueError("Invalid label in JSON data.")

        return image, label

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载JSON数据到两个不同的本地设备
device0 = torch.device("cuda:0")
device1 = torch.device("cuda:1")

# 读取suzuki.json文件到cuda:0
with open('/local/data1/honzh073/data/suzuki.json', 'r') as json_file:
    json_data_0 = json.load(json_file)
    train_data_0 = json_data_0['train']
    val_data_0 = json_data_0['val']

# 读取hospital54_18.json文件到cuda:1
with open('/local/data1/honzh073/local_repository/FL/learner_json/hospital54_18.json', 'r') as json_file:
    json_data_1 = json.load(json_file)
    train_data_1 = json_data_1['train']
    val_data_1 = json_data_1['val']

# 创建本地数据集
train_dataset_0 = CustomDataset(train_data_0, transform=transform)
val_dataset_0 = CustomDataset(val_data_0, transform=transform)
train_dataset_1 = CustomDataset(train_data_1, transform=transform)
val_dataset_1 = CustomDataset(val_data_1, transform=transform)

# 在本地设备上创建数据加载器
batch_size = 64
train_loader_0 = DataLoader(train_dataset_0, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader_0 = DataLoader(val_dataset_0, batch_size=batch_size, shuffle=False, pin_memory=True)
train_loader_1 = DataLoader(train_dataset_1, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, pin_memory=True)

# 在每个本地设备上定义模型
model_0 = models.resnet101(pretrained=True).to(device0)
model_1 = models.resnet101(pretrained=True).to(device1)

# 在每个本地设备上定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer_0 = torch.optim.SGD(model_0.parameters(), lr=0.001, momentum=0.9)
optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.001, momentum=0.9)

# 联邦学习的迭代次数
num_epochs = 10

# 开始联邦学习的训练
for epoch in range(num_epochs):
    # 在第一个本地设备上进行训练
    model_0.train()
    for local_batch, (local_x, local_y) in enumerate(train_loader_0):
        local_x, local_y = local_x.to(device0), local_y.to(device0)
        optimizer_0.zero_grad()
        local_outputs = model_0(local_x)
        loss = criterion(local_outputs, local_y)
        loss.backward()
        optimizer_0.step()

    # 在第二个本地设备上进行训练
    model_1.train()
    for local_batch, (local_x, local_y) in enumerate(train_loader_1):
        local_x, local_y = local_x.to(device1), local_y.to(device1)
        optimizer_1.zero_grad()
        local_outputs = model_1(local_x)
        loss = criterion(local_outputs, local_y)
        loss.backward()
        optimizer_1.step()

    # 将本地模型上传到服务器并进行全局模型更新
    fl.fedavg([model_0, model_1])

    # 在验证集上测试全局模型的性能（这一步是在本地设备上进行的）
    model_0.eval()
    model_1.eval()

    correct_0, total_0 = 0, 0
    correct_1, total_1 = 0, 0

    with torch.no_grad():
        for local_x, local_y in val_loader_0:
            local_x, local_y = local_x.to(device0), local_y.to(device0)
            outputs = model_0(local_x)
            _, predicted = torch.max(outputs, 1)
            total_0 += local_y.size(0)
            correct_0 += (predicted == local_y).sum().item()

        for local_x, local_y in val_loader_1:
            local_x, local_y = local_x.to(device1), local_y.to(device1)
            outputs = model_1(local_x)
            _, predicted = torch.max(outputs, 1)
            total_1 += local_y.size(0)
            correct_1 += (predicted == local_y).sum().item()

    accuracy_0 = 100 * correct_0 / total_0
    accuracy_1 = 100 * correct_1 / total_1

    print(f'Epoch {epoch + 1}/{num_epochs}, Validation Accuracy on cuda:0: {accuracy_0:.2f}%, Validation Accuracy on cuda:1: {accuracy_1:.2f}%')

# 关闭本地设备
fl.stop_device(device0)
fl.stop_device(device1)


ImportError: cannot import name 'fl' from 'nvflare.apis' (/local/data1/honzh073/anaconda3/envs/nvflare-env/lib/python3.9/site-packages/nvflare/apis/__init__.py)

In [6]:
print('ResNet 101')
test_dataset = CustomDataset('/local/data1/honzh073/local_repository/FL/code/3_single_hospital/csv_files/hospital55.csv', transform=test_transform)
# For test dataset
test_NFF_count = sum(1 for _, label in test_dataset if label == 0)  # 0 NFF
test_AFF_count = sum(1 for _, label in test_dataset if label == 1)  # 1 AFF
print(f"test AFF: {test_AFF_count}, ratio: {test_AFF_count / (test_AFF_count + test_NFF_count):.2f}")
print(f"---- NFF: {test_NFF_count}, ratio: {test_NFF_count / (test_AFF_count + test_NFF_count):.2f}")

test_model(model=resnet101, test_dataset=test_dataset, batch_size=batch_size)
params_count = count_parameters(resnet101)
print(f"number of parameters: {params_count}")


ResNet 101


TypeError: string indices must be integers

In [None]:
# # resnet152
# resnet152 = train_model(train_loader, val_loader, classweight, 
#                         num_epochs=epoch_num, lr=lr, step_size=step_size, gamma=0.1, model_name='resnet152')


In [None]:
# print('ResNet 152')

# test_model(model=resnet152, test_dataset=test_dataset, batch_size=batch_size)
# params_count = count_parameters(resnet152)
# print(f"number of parameters: {params_count}")


In [None]:
# densenet161

# densenet161 = train_model(train_loader, val_loader, classweight, 
#                           num_epochs=epoch_num, lr=lr, step_size=step_size, gamma=0.1, model_name='densenet161')


In [None]:
# test_model(model=densenet161, test_dataset=test_dataset, batch_size=batch_size)

In [None]:
# # vgg19
# vgg19 = train_model(train_loader, val_loader, classweight, 
#                           num_epochs=epoch_num, lr=lr, step_size=step_size, gamma=0.1, model_name='vgg19')


In [None]:
# test_model(model=vgg19, test_dataset=test_dataset, batch_size=batch_size)