In [None]:
import os
# import shutil
# import tempfile
import matplotlib.pyplot as plt
from PIL import Image
import torch
import numpy as np
from sklearn.metrics import classification_report
import cv2
# import tensorflow as tf
# import pandas as pd
# import seaborn as sns
from Densenet import CustomDenseNet121
from monai.apps import download_and_extract
from monai.config import print_config
from monai.metrics import ROCAUCMetric
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    ToTensor,
)
from monai.data import Dataset, DataLoader
from monai.utils import set_determinism


In [None]:
# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing on first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# NOTE(review): absolute, machine-specific path; consider reading it from an
# environment variable or a config cell so the notebook runs on other machines.
data_dir = r'C:\Users\Chirag C\vit\docs\potato\Potato_Leaf'
# Every image is normalised to this (width, height) so the DataLoader can stack
# samples into batches.
TARGET_SIZE = (500, 500)

# One sub-directory per class; sorting keeps label indices stable across runs.
class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))
num_class = len(class_names)
valid_frac, test_frac = 0.1, 0.1
# image_files[i] lists the files of class_names[i]; sorted so the file order does
# not depend on os.listdir's arbitrary ordering.
image_files = [sorted(os.path.join(data_dir, class_name, x)
                      for x in os.listdir(os.path.join(data_dir, class_name)))
               for class_name in class_names]
# Flat, parallel lists: image_label_list[k] is the class index of image_file_list[k].
image_file_list = [path for files in image_files for path in files]
image_label_list = [label for label, files in enumerate(image_files) for _ in files]
if not image_file_list:
    raise ValueError(f"No images found under {data_dir}")

# WARNING: resizing happens IN PLACE and overwrites the original files on disk.
# Files that are already 3-channel, 8-bit and TARGET_SIZE are skipped, so re-running
# this cell does not re-encode (and progressively degrade) JPEGs every time.
for path in image_file_list:
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports unreadable / non-image files by returning None.
        raise ValueError(f"Could not read image file: {path}")
    already_normalised = (
        raw.ndim == 3
        and raw.shape[2] == 3
        and raw.dtype == np.uint8
        and (raw.shape[1], raw.shape[0]) == TARGET_SIZE
    )
    if already_normalised:
        continue
    # Default imread flags decode to 3-channel 8-bit BGR before resizing.
    image = cv2.imread(path)
    resized = cv2.resize(image, TARGET_SIZE, interpolation=cv2.INTER_AREA)
    cv2.imwrite(path, resized)

num_total = len(image_label_list)
# After the loop every image has TARGET_SIZE, so the first file is representative.
with Image.open(image_file_list[0]) as sample_image:
    image_width, image_height = sample_image.size

print('Total image count:', num_total)
print("Image dimensions:", image_width, "x", image_height)
print("Label names:", class_names)
print("Label counts:", [len(image_files[i]) for i in range(num_class)])

In [None]:
# Preview up to nine distinct, randomly chosen images with their class labels.
n_preview = min(9, num_total)
fig, axes = plt.subplots(3, 3, figsize=(8, 8))
preview_idx = np.random.choice(num_total, size=n_preview, replace=False)
for ax, k in zip(axes.ravel(), preview_idx):
    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(image_file_list[k]) as im:
        arr = np.array(im)
    # cmap / vmin / vmax only take effect for single-channel images; RGB is drawn as-is.
    ax.imshow(arr, cmap='gray', vmin=0, vmax=255)
    ax.set_xlabel(class_names[image_label_list[k]])
# Hide any unused panels when the dataset has fewer than nine images.
for ax in axes.ravel()[n_preview:]:
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Seed for the train/val/test assignment so the split is reproducible across runs.
SPLIT_SEED = 0
split_rng = np.random.default_rng(SPLIT_SEED)

trainX, trainY, valX, valY, testX, testY = [], [], [], [], [], []
# One uniform draw per image in [0, 1):
#   [0, valid_frac)                      -> validation
#   [valid_frac, valid_frac + test_frac) -> test
#   otherwise                            -> training
draws = split_rng.random(len(image_file_list))
for path, label, draw in zip(image_file_list, image_label_list, draws):
    if draw < valid_frac:
        valX.append(path)
        valY.append(label)
    elif draw < valid_frac + test_frac:
        testX.append(path)
        testY.append(label)
    else:
        trainX.append(path)
        trainY.append(label)

# Every image lands in exactly one split.
assert len(trainX) + len(valX) + len(testX) == len(image_file_list)
print("Training count =", len(trainX), "Validation count =", len(valX), "Test count =", len(testX))

In [None]:
# Training pipeline: load from file path, move channels first, scale intensities,
# then apply light random augmentation before converting to a tensor.
train_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    ScaleIntensity(),
    # MONAI's RandRotate takes range_x in RADIANS (angle sampled from
    # (-range_x, range_x)); np.pi / 12 is 15 degrees.
    RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
    RandFlip(spatial_axis=0, prob=0.5),
    RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
    ToTensor()
])

# Evaluation pipeline: same preprocessing as training, without augmentation.
val_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    ScaleIntensity(),
    ToTensor()
])

# Post-processing for the ROC-AUC metric: softmax over model logits, and
# one-hot encoding of integer labels into `num_class` channels.
act = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=num_class)

In [None]:
# Centralised baseline: training schedule, network, loss and optimiser.
epoch_num = 2      # passes over train_loader in the centralised training loop
val_interval = 1   # run validation after every `val_interval` epochs

model = CustomDenseNet121(spatial_dims=2, in_channels=3, out_channels=num_class)
model = model.to(device)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:
print(str(model))  # dump the layer-by-layer module summary of the baseline network

In [None]:
class Scann(Dataset):
    """Map-style dataset pairing image file paths with integer class labels.

    Item ``i`` is ``(transforms(image_files[i]), labels[i])``; the transform
    pipeline is responsible for loading the file from disk.
    """

    def __init__(self, image_files, labels, transforms):
        # Parallel sequences: image_files[i] is labelled labels[i].
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        sample = self.transforms(self.image_files[index])
        return sample, self.labels[index]


# One dataset per split; only training data is augmented and shuffled.
train_ds = Scann(trainX, trainY, train_transforms)
val_ds = Scann(valX, valY, val_transforms)
test_ds = Scann(testX, testY, val_transforms)

train_loader = DataLoader(train_ds, batch_size=10, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=10)
test_loader = DataLoader(test_ds, batch_size=10)

In [None]:
def train_and_evaluate_model(parameters=None):
    """Train a fresh CustomDenseNet121 on ``train_loader`` and return its test accuracy.

    Args:
        parameters: optional mapping of hyper-parameters. Recognised keys:
            ``"lr"``     -- Adam learning rate (default ``1e-5``);
            ``"epochs"`` -- number of training epochs (default: notebook-level ``epoch_num``).
            Any non-mapping value (including ``None``) uses those defaults.

    Returns:
        Fraction of samples in ``test_loader`` that are classified correctly.

    Raises:
        ValueError: if the test set is empty, checked before any training is done.
    """
    from collections.abc import Mapping

    num_test = len(test_loader.dataset)
    if num_test == 0:
        raise ValueError("Test set is empty; accuracy is undefined.")

    hparams = parameters if isinstance(parameters, Mapping) else {}
    lr = hparams.get("lr", 1e-5)
    num_epochs = hparams.get("epochs", epoch_num)

    # Local names deliberately shadow the notebook-level model/loss/optimizer so
    # this run starts from fresh weights and does not disturb them.
    model = CustomDenseNet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=num_class
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in range(num_epochs):
        model.train()
        for batch_data in train_loader:
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

    model.eval()
    correct = 0
    with torch.no_grad():
        for batch_data in test_loader:
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            predictions = model(inputs).argmax(dim=1)
            correct += torch.sum(predictions == labels).item()

    return correct / num_test

In [None]:
# Centralised training with per-epoch validation. The checkpoint with the best
# validation ACCURACY is saved to 'best_metric_model.pth'; ROC-AUC is tracked alongside.
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
auc_metric = ROCAUCMetric()
metric_values = list()

# Fail fast instead of dividing by zero when averaging the epoch loss.
if len(train_ds) == 0:
    raise ValueError("Training set is empty; nothing to train on.")
# Optimisation steps per epoch, including the final partial batch.
epoch_len = len(train_loader)

for epoch in range(epoch_num):
    print('-' * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        if len(val_ds) == 0:
            # The random split can leave the validation set empty on small datasets.
            print("validation set is empty; skipping validation")
        else:
            model.eval()
            with torch.no_grad():
                # Accumulate logits and labels for the whole validation set.
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                # ROC-AUC expects per-sample one-hot labels and softmax probabilities.
                y_onehot = [to_onehot(i) for i in y]
                y_pred_act = [act(i) for i in y_pred]
                auc_metric(y_pred_act, y_onehot)
                auc_result = auc_metric.aggregate()
                auc_metric.reset()
                del y_pred_act, y_onehot
                metric_values.append(auc_result)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                # Model selection is by accuracy, not AUC.
                if acc_metric > best_metric:
                    best_metric = acc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), 'best_metric_model.pth')
                    print('saved new best metric model')
                print(f"current epoch: {epoch + 1} current AUC: {auc_result:.4f}"
                      f" current accuracy: {acc_metric:.4f} best accuracy: {best_metric:.4f}"
                      f" at epoch: {best_metric_epoch}")

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")


In [None]:
import torch.nn.functional as F

# Simulated federated averaging (FedAvg): each round, every client starts from the
# global weights, trains locally on its own data shard, and the global model becomes
# the sample-weighted average of the client models.
num_clients = 2
local_epochs = 2
communication_rounds = 10
local_lr = 1e-5  # same learning rate as the centralised optimiser


def _make_model():
    """Build a fresh CustomDenseNet121 with the notebook's configuration on `device`."""
    return CustomDenseNet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=num_class
    ).to(device)


def _train_local(local_model, loader, epochs):
    """Train `local_model` in place on `loader` for `epochs` epochs."""
    # The optimiser must own THIS model's parameters. The notebook-level `optimizer`
    # is bound to `model`, so stepping it would leave `local_model` untouched.
    local_optimizer = torch.optim.Adam(local_model.parameters(), lr=local_lr)
    local_model.train()
    for _ in range(epochs):
        for batch_data in loader:
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            local_optimizer.zero_grad()
            loss = loss_function(local_model(inputs), labels)
            loss.backward()
            local_optimizer.step()


def _federated_average(models, weights):
    """Return a state_dict that is the `weights`-weighted average of the models' states."""
    total = float(sum(weights))
    states = [m.state_dict() for m in models]
    averaged = {}
    for key, reference in states[0].items():
        if torch.is_floating_point(reference):
            # Covers parameters AND floating-point buffers (e.g. normalisation
            # running statistics, if the network has any).
            averaged[key] = sum(state[key] * (w / total) for state, w in zip(states, weights))
        else:
            # Integer buffers (e.g. counters) cannot be averaged; keep the first client's.
            averaged[key] = reference.clone()
    return averaged


def _evaluate_accuracy(net, loader):
    """Fraction of samples in `loader` that `net` classifies correctly."""
    net.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for batch_data in loader:
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            correct += (net(inputs).argmax(dim=1) == labels).sum().item()
            seen += labels.numel()
    return correct / seen


if len(test_loader.dataset) == 0:
    raise ValueError("Test set is empty; accuracy is undefined.")

# Give each client a disjoint, round-robin shard of the training data so the
# clients genuinely see different samples.
client_loaders, client_sizes = [], []
for client_id in range(num_clients):
    shard_x = trainX[client_id::num_clients]
    shard_y = trainY[client_id::num_clients]
    if not shard_x:
        continue  # more clients than training images: this client has no data
    client_loaders.append(DataLoader(Scann(shard_x, shard_y, train_transforms),
                                     batch_size=train_loader.batch_size, shuffle=True))
    client_sizes.append(len(shard_x))
if not client_loaders:
    raise ValueError("Training set is empty; nothing to federate.")

global_model = _make_model()

# Federated learning loop
for comm_round in range(communication_rounds):
    print(f"Communication Round {comm_round + 1}/{communication_rounds}")
    local_models = []

    for client_id, loader in enumerate(client_loaders):
        print(f"Training Local Model on Client {client_id + 1}/{len(client_loaders)}")
        local_model = _make_model()
        local_model.load_state_dict(global_model.state_dict())
        _train_local(local_model, loader, local_epochs)
        local_models.append(local_model)

    # Aggregate client weights, weighting each client by its number of samples.
    global_model.load_state_dict(_federated_average(local_models, client_sizes))
    del local_models  # release the client copies before evaluation

    test_accuracy = _evaluate_accuracy(global_model, test_loader)
    print(f"Test Accuracy after Communication Round {comm_round + 1}: {test_accuracy:.4f}")
