In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from torchvision import models

import matplotlib.pyplot as plt
import numpy as np
import os
from copy import deepcopy
from PIL import Image

import timm
from utils import *
import warnings
warnings.filterwarnings("ignore")
# a = timm.create_model('densenet121', in_chans=3, num_classes=0, pretrained=False)
# print(a)

In [None]:
!nvidia-smi

In [None]:
# Fix the random seeds used by NumPy and PyTorch so repeated runs start from
# the same state. NOTE(review): Python's built-in `random` module is not
# seeded in this cell.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
# Turn off cuDNN auto-tuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from torchvision.transforms import v2
import PIL
# Per-channel statistics handed to v2.Normalize inside CustomDataset.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]


# Side length (pixels) of the square crops produced by both pipelines; also
# forwarded to timm as img_size when the model is built.
size=512

# NOTE(review): `ratio` is not read anywhere in the visible cells.
ratio = 1.0
# Training pipeline: random resized crop, flips, rotation, colour jitter,
# random grayscale and blur, then conversion to a tensor. Normalisation is
# applied afterwards by the dataset, not here.
transform_train = v2.Compose([
                      v2.RandomResizedCrop(size=(size, size), scale= (0.87, 1.0), ratio=(0.7, 1.3), interpolation=v2.InterpolationMode.BICUBIC ), #  ratio=(0.7, 1.3),
                    v2.RandomHorizontalFlip(p=0.5),
                    v2.RandomVerticalFlip(p=0.5),
                    v2.RandomApply([v2.RandomRotation(180)]),
                    v2.RandomApply([v2.ColorJitter(0.2,0.2,0.2,0.1)], 0.2), # brightness, contrast, saturation, hue
                    v2.RandomGrayscale(p=0.25),
                    v2.RandomApply([v2.GaussianBlur(kernel_size=7, sigma=0.5)],0.2),
                    v2.ToTensor(),
                    
        ]) 
# Evaluation pipeline: resize to 1.1 * size, centre-crop back to `size`, then
# convert to a tensor. No random augmentation.
transform_val = v2.Compose([
                    v2.Resize(int(1.1*size), interpolation=PIL.Image.BICUBIC),
                    v2.CenterCrop(size),
                    v2.ToTensor(),
        ]) 

In [None]:
def AverageModel(models):
    """Return a model whose floating-point state is the mean of ``models``.

    Args:
        models: Non-empty sequence of modules with identical ``state_dict``
            keys and tensor shapes.

    Returns:
        A deep copy of ``models[0]`` whose floating-point (or complex)
        parameters and buffers hold the element-wise average over all inputs.
        Integer/bool buffers (for example BatchNorm's ``num_batches_tracked``)
        cannot be averaged in place and keep the value of ``models[0]``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("AverageModel needs at least one model")

    # Average into a copy so the first input model keeps its own weights.
    model_avg = deepcopy(models[0])
    avg_state = model_avg.state_dict()
    other_states = [model.state_dict() for model in models[1:]]

    with torch.no_grad():
        for key, target in avg_state.items():
            if not (target.is_floating_point() or target.is_complex()):
                # In-place true division is not defined for integer tensors.
                continue
            total = target.clone()
            for state in other_states:
                total += state[key].to(device=target.device, dtype=target.dtype)
            # state_dict tensors share storage with the module, so copy_
            # updates model_avg itself.
            target.copy_(total / len(models))
    return model_avg

In [None]:
def check_folder(dir):
    """Make sure the directory ``dir`` exists and return its path.

    Missing parent directories are created as well. ``exist_ok=True`` avoids
    the race between a separate existence check and the creation call.

    Args:
        dir: Path of the directory to create if needed.

    Returns:
        The same ``dir`` value that was passed in.

    Raises:
        FileExistsError: If ``dir`` already exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)
    return dir
import cv2
def Clahe(image):
    """Convert an RGB image to grayscale and equalise it with CLAHE.

    Args:
        image: Anything ``np.asarray`` turns into an array that
            ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)`` accepts.

    Returns:
        A PIL image built from the CLAHE-processed 8-bit grayscale array.
    """
    # RGB -> single-channel array, forced to uint8 for CLAHE.
    rgb_pixels = np.asarray(image)
    gray_u8 = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY).astype(np.uint8)

    # Contrast-limited adaptive histogram equalisation on an 8x8 tile grid.
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = equalizer.apply(gray_u8)

    # Hand the result back as a PIL image.
    return Image.fromarray(equalized)

class CustomDataset(Dataset):
    """Folder-per-class image dataset that yields normalised RGB tensors.

    ``root`` must contain one sub-directory per class; the sorted directory
    names define the class indices. A parallel ``vessel`` tree with the same
    layout is recorded per sample, but ``__getitem__`` does not read it.
    """

    def __init__(self, root, vessel,transform=None):
        """
        Args:
            root (string): Directory with one sub-directory per class.
            vessel (string): Directory with the same class/file layout as
                ``root``; only the resulting paths are stored.
            transform (callable, optional): Transform applied to the PIL
                image before normalisation. It must return a tensor; when
                omitted, a plain ``v2.ToTensor()`` conversion is used.
        """
        # Normalisation needs a tensor, so fall back to a bare conversion
        # instead of failing in __getitem__ when no transform is given.
        self.transform = transform if transform is not None else v2.ToTensor()
        # Only directories are classes; stray files in root (for example
        # hidden OS files) would otherwise break the listing below.
        self.classes = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )  # list of class names
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}  # mapping from class name to index
        self.images = []  # list of tuples (image_path, class_index, vessel_path)
        self.normalize_color = v2.Normalize(mean, std)
        # NOTE(review): hard-coded and independent of len(self.classes).
        self.n_classes = 5

        for class_name in self.classes:
            class_dir = os.path.join(root, class_name)
            vessel_dir = os.path.join(vessel, class_name)
            for image_name in sorted(os.listdir(class_dir)):
                image_path = os.path.join(class_dir, image_name)
                vessel_path = os.path.join(vessel_dir, image_name)
                self.images.append((image_path, self.class_to_idx[class_name], vessel_path))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, transform and normalise the sample at ``idx``.

        Returns:
            dict: ``{"image": normalised tensor, "label": class index}``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, class_index, vessel_path = self.images[idx]
        # Force three channels so grayscale/RGBA files still match the
        # 3-channel normalisation, and close the file handle right away.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")
        image = self.transform(image)
        image = self.normalize_color(image)

        return {"image":image, "label":class_index}

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 3
# Dataset tag; reused later when building checkpoint file names.
img_type = 'DDR'
# NOTE(review): `path` is not read anywhere in the visible cells.
path = img_type + '_dataset/'
# NOTE(review): absolute, machine-specific path; must be edited to run elsewhere.
img_dir = '/home/monetai2/Desktop/LabFolder/Pham/datasets/DDR/preprocessed/'
# The same folder is passed as both `root` and `vessel`.
train_data = CustomDataset(root=img_dir+'train/', vessel=img_dir+'train/', transform=transform_train)
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
print(len(train_data))
print(train_data.class_to_idx)

# Validation split: deterministic transform, fixed order.
val_data = CustomDataset(root=img_dir+'val/', vessel=img_dir+'val/', transform=transform_val)
valloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
print(len(val_data))
print(val_data.class_to_idx)

# Directory (relative to the working directory) that receives the saved weights.
checkpoint_path = 'checkpoint/'
check_folder(checkpoint_path)

In [None]:
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised (C, H, W) image tensor with matplotlib.

    The tensor is un-normalised with the same per-channel statistics the
    dataset uses for normalisation, rather than a fixed 0.5/0.5 pair, and
    clipped to [0, 1] so matplotlib renders it without range warnings.

    Args:
        img: Image tensor laid out as (channels, height, width).
        mean: Per-channel mean used when the image was normalised.
        std: Per-channel standard deviation used when it was normalised.
    """
    img = img.detach().cpu()
    channel_mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    channel_std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    img = (img * channel_std + channel_mean).clamp(0, 1)     # unnormalize
    npimg = img.numpy()
    # matplotlib expects (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [None]:
# Pull one batch from the training loader as a visual sanity check.
data_iter = iter(trainloader)
data = next(data_iter)

# imshow(torchvision.utils.make_grid(data["vessel"]))
# print labels (join the values; printing the bare generator expression only
# shows a "<generator object ...>" repr)
print(' '.join(str(label.item()) for label in data["label"]))
# show images
imshow(torchvision.utils.make_grid(data["image"]))

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:

import timm
import torch
import torch.nn as nn

class SqueezeExcitation(nn.Module):
    """Channel gate: global average pool -> FC -> ReLU -> FC -> sigmoid.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Divisor for the width of the hidden FC layer.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        hidden = in_channels // reduction_ratio

        # Squeeze: one value per channel.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        # Excitation: bias-free bottleneck of two linear layers.
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale every channel of the 4-D input ``x`` by its learned gate."""
        batch, chans, _, _ = x.size()

        # (B, C, 1, 1) -> (B, C)
        pooled = self.global_avg_pool(x).view(batch, chans)

        # Bottleneck MLP followed by a sigmoid produces per-channel weights.
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))

        # Broadcast the weights over the spatial dimensions.
        return x * gate.view(batch, chans, 1, 1)

class ChannelAttention4x(nn.Module):
    """Channel gate whose hidden width is ``channels * expand``.

    The gate is computed from globally average-pooled features by three
    bias-free 1x1 convolutions (the middle one with one group per hidden
    channel) and a final sigmoid.
    """

    def __init__(self, channels, expand=4):
        super().__init__()
        hidden = channels * expand

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 1, groups=hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its gate."""
        gate = self.fc(self.avg_pool(x))
        return x * gate
class SpatialAttention(nn.Module):
    """Spatial gate built from the channel-wise mean and max of the input.

    Args:
        kernel_size: Size of the bias-free convolution that turns the
            2-channel (mean, max) map into a single attention map.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its (B, 1, H, W) spatial attention map.

        The gate is applied to the input features themselves. The earlier
        code overwrote ``x`` with the convolution output and therefore
        returned a single-channel map instead of the re-weighted features.
        """
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention = self.conv1(torch.cat([avg_out, max_out], dim=1))
        # Broadcasts over the channel dimension of x.
        return self.sigmoid(attention) * x
class ChannelAttention(nn.Module):
    """Channel gate computed from average- and max-pooled descriptors.

    Args:
        in_planes: Number of input channels.
        ratio: Reduction factor of the shared bottleneck; the hidden width is
            ``in_planes // ratio``. It was previously ignored in favour of a
            hard-coded 16, which is still the default.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden_planes = in_planes // ratio
        # The same bottleneck is applied to both pooled descriptors.
        self.fc = nn.Sequential(nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(hidden_planes, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the sigmoid of the summed descriptors."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)*x
class CBAM(nn.Module):
    """Applies ``ChannelAttention`` and then ``SpatialAttention`` to a feature map.

    Args:
        in_chans: Channel count passed to the channel-attention stage.
    """

    def __init__(self, in_chans):
        super().__init__()

        self.ca = ChannelAttention(in_chans)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Run the channel stage first and feed its output to the spatial stage."""
        channel_refined = self.ca(x)
        return self.sa(channel_refined)
class MKB(nn.Module):
    """Three parallel conv branches with kernel sizes 3, 5 and 7.

    Every branch is Conv2d -> BatchNorm2d -> ReLU -> SqueezeExcitation with
    'same' padding. Each branch output is added to the input, the three
    results are concatenated along the channel axis, fused by two bias-free
    1x1 convolutions, and the input is added once more.

    The residual additions require the input to have ``out_channels``
    channels, so the block is only usable with
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def make_branch(kernel_size):
            # One conv branch; only the kernel size differs between branches.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same'),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                SqueezeExcitation(out_channels),
            )

        self.conv1 = make_branch(3)   # 3x3 branch
        self.conv2 = make_branch(5)   # 5x5 branch
        self.conv3 = make_branch(7)   # 7x7 branch

        # Fuse the concatenated branches back to out_channels, then mix once more.
        self.conv1x1_1 = nn.Conv2d(out_channels*3,out_channels, kernel_size=1, bias=False)
        # self.relu1 = nn.ReLU()
        self.conv1x1_2 = nn.Conv2d(out_channels,out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``x`` plus the fused multi-kernel features."""
        # Each branch carries its own residual connection to the input.
        branch_outputs = [branch(x) + x for branch in (self.conv1, self.conv2, self.conv3)]

        # Concatenate along the channel axis (not a sum), then fuse.
        stacked = torch.cat(branch_outputs, dim=1)
        fused = self.conv1x1_2(self.conv1x1_1(stacked))
        return x + fused


class CustomModel(nn.Module):
    """timm ``coatnet_rmlp_2_rw_224`` backbone with MKB blocks inserted.

    An MKB block follows the stem and each of the first two stages. The
    backbone is created with pretrained weights, five output classes and the
    notebook-level ``size`` as image size.
    """

    def __init__(self):
        super().__init__()
        self.net1 = timm.create_model('coatnet_rmlp_2_rw_224', img_size=size, in_chans=3,  num_classes=5,  pretrained=True)
        # Channel counts must match the backbone features at each insertion point.
        self.ca_stem = MKB(128,128)
        self.ca_stage0 = MKB(128,128)
        self.ca_stage1 = MKB(256,256)

    def forward_features(self, x):
        """Run stem, stages and final norm, applying the MKB blocks on the way."""
        x = self.ca_stem(self.net1.stem(x))

        for index, stage in enumerate(self.net1.stages):
            x = stage(x)
            # Only the outputs of the first two stages are refined.
            if index == 0:
                x = self.ca_stage0(x)
            elif index == 1:
                x = self.ca_stage1(x)
        return self.net1.norm(x)

    def forward(self, x, std=0.0):
        """Return ``(logits, features)``.

        ``std`` is accepted for call compatibility but is not used.
        """
        ft = self.forward_features(x)
        return self.net1.forward_head(ft), ft
# Architecture tag that becomes part of the checkpoint file names.
network = 'proposed'
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


# Build the network; CustomModel.__init__ asks timm for pretrained weights.
net = CustomModel()
# net.load_state_dict(torch.load("checkpoint/DDR_seed0_proposed512_kappa_ema.pth"))



# Move all parameters and buffers to the selected device.
net = net.to(device)


In [None]:
class ModelEmaV2(nn.Module):
    """Keeps an exponential moving average (EMA) copy of a model's state.

    Args:
        model: Model whose state is tracked; it is deep-copied, not shared.
        decay: Weight given to the existing EMA value on every update.
        device: Optional device on which the EMA copy lives and is updated.
    """

    def __init__(self, model, decay=0.9998, device=None):
        super().__init__()
        # Independent copy that accumulates the averaged weights; kept in eval mode.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite each EMA state tensor with ``update_fn(ema, current)``."""
        ema_tensors = self.module.state_dict().values()
        model_tensors = model.state_dict().values()
        with torch.no_grad():
            for ema_tensor, model_tensor in zip(ema_tensors, model_tensors):
                if self.device is not None:
                    model_tensor = model_tensor.to(device=self.device)
                ema_tensor.copy_(update_fn(ema_tensor, model_tensor))

    def update(self, model):
        """Blend the current model state into the EMA using ``self.decay``."""
        def blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=blend)

    def update_decay(self, new_decay):
        """Replace the decay used by later ``update`` calls."""
        self.decay = new_decay

    def set(self, model):
        """Copy the model state into the EMA copy without averaging."""
        def take_model(e, m):
            return m

        self._update(model, update_fn=take_model)
# EMA shadow of `net`, kept on the same device; uses the class default decay (0.9998).
model_ema = ModelEmaV2(net, device=device)

In [None]:
# Number of training epochs; also the length of the history buffers below.
epochs = 30

# Zero-initialised per-epoch history buffers (train/test loss and accuracy).
# They are already module-level names, so no `global` declarations are needed.
trn_loss, tst_loss, trn_acc, tst_acc = (torch.zeros(epochs) for _ in range(4))


In [None]:

def train(epoch, coeff=0, coeff_ema=0.9996, detect_anomaly=False):
    """Train ``net`` for one pass over ``trainloader``.

    Uses the notebook-level ``net``, ``trainloader``, ``optimizer``,
    ``criterion``, ``model_ema``, ``device`` and ``progress_bar``.

    Args:
        epoch: Epoch index; not used inside the function.
        coeff: Not used inside the function.
        coeff_ema: Not used while the ``update_decay`` call stays commented out.
        detect_anomaly: Enable autograd anomaly detection for this pass. It is
            a debugging aid that slows training, so it is off by default; pass
            True to get the earlier always-on behaviour back.

    Returns:
        Sum of the per-batch loss values (not averaged over batches).
    """
    import contextlib

    # model_ema.update_decay(coeff_ema)
    # print('ema weight' , coeff_ema)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    anomaly_ctx = (torch.autograd.set_detect_anomaly(True)
                   if detect_anomaly else contextlib.nullcontext())
    with anomaly_ctx:
        for batch_idx, data in enumerate(trainloader):

            images = data["image"].to(device)
            labels = data["label"].to(device)

            optimizer.zero_grad()
            # Do not forward the notebook-level `std` (the normalisation
            # list); the model's own `std` argument keeps its default.
            outputs, f1 = net(images)
            _, predicted = outputs.max(1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            # Refresh the EMA shadow after every optimiser step.
            if model_ema is not None:
                model_ema.update(net)

            train_loss += loss.item()

            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return train_loss

In [None]:
def val(epoch):
    """Evaluate ``net`` and the EMA model on ``valloader``.

    Uses the notebook-level ``net``, ``model_ema``, ``valloader``,
    ``criterion``, ``device`` and ``progress_bar``. ``epoch`` is not used.

    Returns:
        tuple: (summed batch loss of ``net``, predictions of ``net``,
        ground-truth labels, accuracy of ``net`` in percent, predictions of
        the EMA model); the three label collections are NumPy arrays.
    """
    net.eval()
    # net_ema.eval()
    ema_net = model_ema.module

    running_loss = 0
    n_correct = 0
    n_seen = 0
    preds, targets, preds_ema = [], [], []

    with torch.no_grad():
        for step, batch in enumerate(valloader):
            images = batch["image"].to(device)
            labels = batch["label"].to(device)

            # Live model: logits, predictions and loss.
            logits, _ = net(images, std=1.)
            predicted = logits.max(1)[1]
            running_loss += criterion(logits, labels).item()

            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            progress_bar(step, len(valloader), 'Val Loss: %.3f | Val Acc: %.3f%% (%d/%d)' % (running_loss/(step+1), 100.*n_correct/n_seen, n_correct, n_seen))

            preds.extend(predicted.cpu().numpy().tolist())
            targets.extend(labels.cpu().numpy().flatten().tolist())

            # EMA model: predictions only.
            logits_ema, _ = ema_net(images, std=1.)
            predicted_ema = logits_ema.max(1)[1]
            preds_ema.extend(predicted_ema.cpu().numpy().tolist())

    acc = 100.*n_correct/n_seen

    return running_loss, np.asarray(preds), np.asarray(targets), acc, np.asarray(preds_ema)


In [None]:
# from torch.nn import functional as F
# class FocalLoss(torch.nn.Module):
#     def __init__(self, alpha, gamma=2.0):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha  # Precomputed alpha values
#         self.gamma = gamma

#     def forward(self, inputs, targets):
#         log_probs = F.log_softmax(inputs, dim=1)  # Log-probabilities
#         probs = torch.exp(log_probs)             # Probabilities
#         log_p = log_probs[range(len(targets)), targets]
#         p = probs[range(len(targets)), targets]

#         focal_term = (1 - p) ** self.gamma
#         loss = -focal_term * log_p

#         # Apply alpha values
#         alpha_t = self.alpha[targets]  # Select alpha for true classes
#         loss = alpha_t * loss

#         return loss.mean()

# class_counts = np.array([3133, 315, 2238, 118, 456])  # Class 0: 100 samples, Class 1: 300 samples
# total_samples = sum(class_counts)
# class_weights = total_samples / class_counts
# print("Class Weights:", class_weights)
# class_weights = class_weights / class_weights.sum()
# print("Class Weights:", class_weights)

# # Convert to a PyTorch tensor
# class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
# Classification loss used by train(), val() and test(): cross-entropy with label smoothing 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# criterion = FocalLoss(alpha=class_weights_tensor, gamma=2.0)
# NOTE(review): criterion_mse is not read anywhere in the visible cells.
criterion_mse = nn.MSELoss()
# Fixed learning rate; every scheduler below is commented out, so it never changes.
lr = 1e-5
optimizer = optim.AdamW(net.parameters(), lr=lr, weight_decay = 0.01)
# optimizer = optim.Adam(net.parameters(), lr=1e-4)
# scheduler = StepLRScheduler(optimizer, decay_t=10, decay_rate=0.8)
import torch.optim.lr_scheduler as lr_scheduler
# scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# scheduler1 = torch.optim.lr_scheduler.OneCycleLR(optimizer, 
#                                                    max_lr=lr, 
#                                                    total_steps=epochs*len(trainloader))
# scheduler = PlateauLRScheduler(optimizer, decay_rate=0.2, patience_t=3, warmup_t=3, warmup_lr_init=1e-7)
# scheduler = CosineLRScheduler(optimizer, t_initial=epochs, warmup_t=3, warmup_lr_init=1e-7)
# NOTE(review): reassigned to -1. in the training cell before it is first compared.
best_acc = 0  # best test accuracy
# base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
# opt = torchcontrib.optim.SWA(base_opt)

In [None]:
import time
import os
from sklearn.metrics import roc_curve, recall_score, precision_score, f1_score, cohen_kappa_score, accuracy_score, auc, classification_report, confusion_matrix
from imblearn.metrics import sensitivity_score, specificity_score
# def quadratic_weighted_kappa(conf_mat):
#     assert conf_mat.shape[0] == conf_mat.shape[1]
#     cate_num = conf_mat.shape[0]

#     # Quadratic weighted matrix
#     weighted_matrix = np.zeros((cate_num, cate_num))
#     for i in range(cate_num):
#         for j in range(cate_num):
#             weighted_matrix[i][j] = 1 - float(((i - j)**2) / ((cate_num - 1)**2))

#     # Expected matrix
#     ground_truth_count = np.sum(conf_mat, axis=1)
#     pred_count = np.sum(conf_mat, axis=0)
#     expected_matrix = np.outer(ground_truth_count, pred_count)

#     # Normalization
#     conf_mat = conf_mat / conf_mat.sum()
#     expected_matrix = expected_matrix / expected_matrix.sum()

#     observed = (conf_mat * weighted_matrix).sum()
#     expected = (expected_matrix * weighted_matrix).sum()
#     return (observed - expected) / (1 - expected)
def quadratic_weighted_kappa(conf_mat):
    """Compute the quadratic weighted kappa of a square confusion matrix.

    Args:
        conf_mat: Square array-like of counts, rows = ground truth,
            columns = predictions, categories in ordinal order.

    Returns:
        The kappa value. ``nan`` is returned for the degenerate inputs where
        the statistic is undefined: fewer than two categories, an all-zero
        matrix, or an expected agreement of exactly 1 (all counts in a single
        diagonal cell).
    """
    conf_mat = np.asarray(conf_mat, dtype=float)
    assert conf_mat.shape[0] == conf_mat.shape[1]
    cate_num = conf_mat.shape[0]

    total = conf_mat.sum()
    if cate_num < 2 or total == 0:
        # (cate_num - 1) ** 2 or the normalisation would divide by zero.
        return float('nan')

    # Quadratic agreement weights: 1 on the diagonal, 0 at maximal distance.
    grades = np.arange(cate_num)
    weighted_matrix = 1.0 - (grades[:, None] - grades[None, :]) ** 2 / (cate_num - 1) ** 2

    # Expected matrix under independence of the two marginals.
    ground_truth_count = conf_mat.sum(axis=1)
    pred_count = conf_mat.sum(axis=0)
    expected_matrix = np.outer(ground_truth_count, pred_count)

    # Normalisation
    conf_mat = conf_mat / total
    expected_matrix = expected_matrix / expected_matrix.sum()

    observed = (conf_mat * weighted_matrix).sum()
    expected = (expected_matrix * weighted_matrix).sum()
    if expected == 1:
        return float('nan')
    return (observed - expected) / (1 - expected)
def evaluation_metrics(y_true, y_pred, labels=None):
    """Compute and print accuracy, weighted-average metrics and quadratic kappa.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        labels: Optional full, ordered list of class indices for the confusion
            matrix. Pass it when a class may be absent from both ``y_true``
            and ``y_pred``; otherwise the matrix shrinks and the kappa weights
            are computed for the wrong number of categories. The default
            (None) keeps the previous behaviour.

    Returns:
        tuple: (accuracy, precision, recall, f1, specificity, kappa);
        precision, recall, f1 and specificity use ``average='weighted'``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    specificity = specificity_score(y_true, y_pred, average='weighted')
    conf = confusion_matrix(y_true, y_pred, labels=labels)
    kappa = quadratic_weighted_kappa(conf)

    # The averages above are 'weighted', so label them as such rather than 'micro'.
    print('weighted acc, pre, recall, f1, specificity, kappa: ', accuracy, precision, recall, f1, specificity, kappa)
    # for i in range(ground_truth.shape[1]):
    #     print('class ',i)
    #     print(accuracy_score(ground_truth[:,i], pred[:,i]), precision_score(ground_truth[:,i], pred[:,i]), 
    #           recall_score(ground_truth[:,i], pred[:,i]), f1_score(ground_truth[:,i], pred[:,i]), specificity_score(ground_truth[:,i], pred[:,i]))

    return accuracy, precision, recall, f1, specificity, kappa
# Best-so-far trackers for the training loop. The `_ema` variants track the
# EMA model, the others the live model.
# best_model / best_model_ema hold the snapshots with the best validation accuracy.
best_model = None
best_model_ema = None
best_kappa = -1.
best_kappa_ema = -1.
# Per-epoch validation kappa history.
kappa_list = []
kappa_list_ema = []

# Snapshots with the best validation kappa.
best_model_kappa = None
best_model_ema_kappa = None
best_acc = -1.
best_acc_ema = -1.
# Per-epoch validation accuracy history.
acc_list = []
acc_list_ema = []

# NOTE(review): `start` is not read anywhere in the visible cells.
start = time.time()
start_epoch = 0  # start from epoch 0 
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: optional linear warm-up, then cosine decay.

    Args:
        base_value: Value at the end of warm-up / start of the cosine phase.
        final_value: Value the cosine phase decays towards.
        epochs: Total number of epochs covered by the schedule.
        niter_per_ep: Iterations per epoch.
        warmup_epochs: Epochs spent ramping linearly up to ``base_value``.
        start_warmup_value: Value at the first warm-up iteration.

    Returns:
        1-D NumPy array with ``epochs * niter_per_ep`` entries.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep

    # Linear ramp (empty when no warm-up is requested).
    if warmup_epochs > 0:
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_part = np.array([])

    # Half-cosine from base_value towards final_value over the remaining steps.
    steps = np.arange(total_iters - warmup_iters)
    cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

    schedule = np.concatenate((warmup_part, cosine_part))
    assert len(schedule) == total_iters
    return schedule

# no_noise = True
# if no_noise == True:
#     coff = [0]*epochs
# else:
#     coff = cosine_scheduler(0.,0.5,epochs,1)
def linear_increase(a, b, n):
    """
    Return ``n`` evenly spaced values that start at ``a`` and end at ``b``.

    Parameters:
    a (float): First value of the ramp.
    b (float): Last value of the ramp (reached only when n > 1).
    n (int): How many values to produce.

    Returns:
    list: The ramp values; ``[a]`` when n == 1.

    Raises:
    ValueError: If n is smaller than 1.
    """
    if n < 1:
        raise ValueError("Number of steps n must be at least 1")

    # A single step has no increment; otherwise divide the span evenly.
    if n > 1:
        step = (b - a) / (n - 1)
    else:
        step = 0

    values = []
    for i in range(n):
        values.append(a + step * i)
    return values
# Per-epoch coefficient ramps. They are passed to train(), which currently
# ignores both arguments, so they only affect the printed log.
coff_list = linear_increase(0., 0.5, epochs)
coff_ema = linear_increase(0.9995,0.9999,epochs)
# Position inside the ramps; kept separate from `epoch` so start_epoch may be non-zero.
i = 0
for epoch in range(start_epoch, start_epoch+epochs):
    coff = coff_list[i]

    print('epoch ', epoch)
    print('coeff:', coff)
    print('coeff ema:', coff_ema[i])
    train_loss = train(epoch, coeff=coff, coeff_ema=coff_ema[i])
    i += 1
    # One validation pass returns predictions of both the live and the EMA model.
    val_loss, pred, ground_truth, acc, pred_ema = val(epoch)
    print('model')
    # print(classification_report(ground_truth, pred, digits=4))
    accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred)
    kappa_list.append(kappa)
    acc_list.append(accuracy)
    # Snapshot and save the live model whenever validation kappa improves.
    if kappa > best_kappa:
        best_kappa = kappa
        best_model_kappa = deepcopy(net)
        torch.save(best_model_kappa.state_dict(), checkpoint_path+img_type+'_seed'+str(seed)+'_'+network+str(size)+'_kappa.pth')

    # Same for validation accuracy.
    if accuracy > best_acc:
        best_acc = accuracy
        best_model = deepcopy(net)
        torch.save(best_model.state_dict(), checkpoint_path+img_type+'_seed'+str(seed)+'_'+network+str(size)+'_accuracy.pth')

   
    # Repeat the bookkeeping for the EMA model; `accuracy` and `kappa` are
    # rebound to the EMA values from here on.
    print('ema')
    accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred_ema)
    kappa_list_ema.append(kappa)
    acc_list_ema.append(accuracy)
    print('kappa: ', kappa)
    # kappa = accuracy
    if kappa > best_kappa_ema:
        best_kappa_ema = kappa
        best_model_ema_kappa = deepcopy(model_ema.module)
        torch.save(best_model_ema_kappa.state_dict(), checkpoint_path+img_type+'_seed'+str(seed)+'_'+network+str(size)+'_kappa_ema.pth')

    if accuracy > best_acc_ema:
        best_acc_ema = accuracy
        best_model_ema = deepcopy(model_ema.module)
        torch.save(best_model_ema.state_dict(), checkpoint_path+img_type+'_seed'+str(seed)+'_'+network+str(size)+'_accuracy_ema.pth')



In [None]:
# Report the best validation kappa/accuracy and the position in the history
# list at which each was reached, for the live and the EMA model.
print(' ')
print('Validation')
max_kappa = np.amax(kappa_list)
id_kappa = kappa_list.index(max_kappa)
print('max kappa {}, at {}'.format(max_kappa,id_kappa))
max_kappa_ema = np.amax(kappa_list_ema)
id_kappa_ema = kappa_list_ema.index(max_kappa_ema)
print('max kappa ema {}, at {}'.format(max_kappa_ema,id_kappa_ema))

# NOTE(review): the *_kappa names are reused below and hold ACCURACY values
# from here on.
max_kappa = np.amax(acc_list)
id_kappa = acc_list.index(max_kappa)
print('max accuracy {}, at {}'.format(max_kappa,id_kappa))
max_kappa_ema = np.amax(acc_list_ema)
id_kappa_ema = acc_list_ema.index(max_kappa_ema)
print('max accuracy ema {}, at {}'.format(max_kappa_ema,id_kappa_ema))

In [None]:
# Per-epoch validation accuracy of the live model, in training order.
print(acc_list)

In [None]:


# os.environ["CUDA_VISIBLE_DEVICES"]="2"
# device = 'cuda'
# Module-level `acc`; test() below reads this global for its returned accuracy slot.
acc= 0
# del optimizer
# NOTE(review): the message mentions noise, but no noise is added in the visible code.
print('testing with noise')
# Test split: same deterministic transform as validation, fixed order.
test_data = CustomDataset(root=img_dir+'test/', vessel=img_dir+'test/', transform=transform_val)
testloader = DataLoader(test_data, batch_size=4, shuffle=False)
print(len(test_data))
print(test_data.class_to_idx)
from sklearn.metrics import roc_auc_score
def test(net, net_ema):
    """Evaluate a model and its EMA counterpart on ``testloader``.

    Uses the notebook-level ``testloader``, ``criterion``, ``device`` and
    ``progress_bar``.

    Args:
        net: Model to evaluate; loss and accuracy are computed for it.
        net_ema: EMA model; only predictions and scores are collected.

    Returns:
        tuple: (summed batch loss of ``net``, predictions of ``net``,
        ground-truth labels, accuracy of ``net`` in percent, predictions of
        ``net_ema``, softmax scores of ``net``, softmax scores of
        ``net_ema``). Everything except the loss and the accuracy is a NumPy
        array. The accuracy is computed here instead of being read from a
        module-level ``acc`` variable.
    """
    net.eval()
    net_ema.eval()
    # net = net.to(device)
    # net_ema = net_ema.to(device)

    test_loss = 0
    correct = 0
    total = 0
    pred = []
    ground_truth = []
    pred_ema = []
    score = []
    score_ema = []
    with torch.no_grad():
        for batch_idx, data in enumerate(testloader):

            images = data["image"].to(device)
            labels = data["label"].to(device)
            # lesions = data["vessel"].to(device)

            outputs,_ = net(images, std=1.0)
            _, predicted = outputs.max(1)
            # _, labels = labels.max(1)

            loss = criterion(outputs, labels)

            test_loss += loss.item()

            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            progress_bar(batch_idx, len(testloader), 'Test Loss: %.3f | Test Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            # Class probabilities are kept for the ROC-AUC computation.
            outputs = torch.softmax(outputs, dim=1)
            pred.extend(predicted.data.cpu().numpy().tolist())
            score.extend(outputs.data.cpu().numpy().tolist())
            ground_truth.extend(labels.data.cpu().numpy().flatten().tolist())

            outputs_ema,_ = net_ema(images, std=1.0)
            _, predicted_ema = outputs_ema.max(1)
            outputs_ema = torch.softmax(outputs_ema, dim=1)
            score_ema.extend(outputs_ema.data.cpu().numpy().tolist())
            pred_ema.extend(predicted_ema.data.cpu().numpy().tolist())

    # Accuracy of `net` on this run; 0.0 when the loader yields no samples.
    acc = 100.*correct/total if total else 0.0

    return test_loss, np.asarray(pred), np.asarray(ground_truth), acc, np.asarray(pred_ema), np.asarray(score), np.asarray(score_ema)

# Evaluate the snapshots with the best validation ACCURACY on the test split.
# NOTE(review): this rebinds the notebook-level `net`, so re-running earlier
# training cells afterwards would operate on the snapshot instead.
print('best accuracy')
net = best_model
net_ema = best_model_ema

val_loss, pred, ground_truth, acc, pred_ema, score, score_ema = test(net, net_ema)
print('model')
print(classification_report(ground_truth, pred, digits=4))
accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred)
# Macro-averaged one-vs-rest ROC-AUC from the softmax scores.
print("auc ", roc_auc_score(ground_truth, score, average='macro', multi_class='ovr'))
print('ema')
print(classification_report(ground_truth, pred_ema, digits=4))
accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred_ema)
print("auc ", roc_auc_score(ground_truth, score_ema, average='macro', multi_class='ovr'))

# Same report for the snapshots with the best validation KAPPA.
print('best kappa')
net = best_model_kappa
net_ema = best_model_ema_kappa
val_loss, pred, ground_truth, acc, pred_ema, score, score_ema = test(net, net_ema)


print('model')
print(classification_report(ground_truth, pred, digits=4))
accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred)
print("auc ", roc_auc_score(ground_truth, score, average='macro', multi_class='ovr'))
print('ema')
print(classification_report(ground_truth, pred_ema, digits=4))
accuracy, precision, recall, f1, spec, kappa = evaluation_metrics(ground_truth, pred_ema)
print("auc ", roc_auc_score(ground_truth, score_ema, average='macro', multi_class='ovr'))



# 

##### 