In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

print("Pytorch version: ", torch.__version__)
print("Torchvision version: ", torchvision.__version__)

Pytorch version:  2.1.2
Torchvision version:  0.16.2


In [5]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from skimage import io
from PIL import Image
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from sklearn.metrics import confusion_matrix, precision_score, recall_score

from torch.utils.tensorboard import SummaryWriter  # TensorBoard writer

from load_ISIC2018 import load_ISIC2018_GT
from dataset_from_df import labeled_dataset_from_path, unlabeled_dataset_from_path
import opts
from models.nets import densenet

In [6]:
from easydict import EasyDict
from datetime import datetime
from pytz import timezone

# 설정 값을 EasyDict로 정의
args = EasyDict({
    "data_root": "/home/sg980429/research/All_data",
    "save_dir": "../save_files",
    'data' : 'ISIC2018',
    "lbl_ratio": 0.1,
    "resize": 224,
    "batch_size": 32,
    "num_workers": 4,
    "gpu_num": 1,
    'gpu': 1,
    'arch': 'DenseNet',
    "drop_rate": 0.5,
    "lr": 1e-3,
    "ft_lr": 1e-4,
    'num_classes':7,
    "num_epochs": 128,
    'RL': 'PPO',
    "episode_size": 16,
    "mini_batch_size": 4,
    "mini_num_epochs": 4,
    "controller_gpu": 2,
    "experiment_num": 789,
    # "experiment_num": datetime.now(timezone('Asia/Seoul')).strftime("%Y%m%d_%H%M%S"),

    "topk" : 50,
    })

device = torch.device(f"cuda:{args.gpu_num}" if torch.cuda.is_available() else "cpu")

In [7]:
data_root = "/home/sg980429/research/All_data"

In [8]:
args.experiment_num = 123 
args.seed = args.experiment_num 

In [9]:
exp_dir = os.path.join(args.save_dir,f'{args.data}_{args.arch}_{args.RL}_{args.experiment_num}')

In [10]:
teacher_model_dir = os.path.join(exp_dir,"teacher_model")

task_predictor_dir = os.path.join(exp_dir,"task_predictor")

df_path = os.path.join(exp_dir,"train_df.pkl")

In [11]:
_, val_df, test_df = load_ISIC2018_GT(data_root)
training_df = pd.read_pickle(df_path)
num_classes = len(test_df['GT'][0])

In [12]:
x = 0.85
indices = training_df[training_df['contribution'] >= x].index

In [13]:
len(indices)

2609

In [14]:
args.seed = args.experiment_num
np.random.seed(args.seed)
train_indexes = np.random.permutation(len(training_df))
val_indexes = np.random.permutation(len(val_df))
test_indexes = np.arange(len(test_df))

num_labels = int(len(training_df)*args.lbl_ratio)
lbl_indexes = train_indexes[:num_labels]
ulb_indexes = train_indexes[num_labels:]

In [15]:
lbl_indexes = np.array(lbl_indexes)  
# indices = np.array(indices)          

# 배열 결합
# combined_array = np.concatenate((lbl_indexes, indices))

In [16]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


# train_transform = transforms.Compose(
#     [
#         transforms.RandomResizedCrop((args.resize, args.resize), scale=(0.8, 1.0)),  # Crop with a conservative scale to retain lesions
#         transforms.RandomHorizontalFlip(),  # Horizontal flip
#         transforms.RandomVerticalFlip(),  # Vertical flip
#         transforms.RandomRotation(degrees=15),  # Rotate by a small degree
#         transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Slight color adjustments
#         transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),  # Gentle Gaussian Blur
#         transforms.ToTensor(),
#         transforms.Normalize(mean, std),
#     ]
# )
# train_transform = transforms.Compose(
#     [
#         transforms.RandomResizedCrop((args.resize,args.resize), scale=(0.2, 1.0)),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean, std),
#     ]
# )
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop((args.resize, args.resize), scale=(0.8, 1.0)),  # Crop with a conservative scale to retain lesions
        transforms.RandomHorizontalFlip(),  # Horizontal flip
        transforms.RandomVerticalFlip(),  # Vertical flip
        transforms.RandomRotation(degrees=15),  # Rotate by a small degree
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),  # Slight color adjustments
        transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),  # Gentle Gaussian Blur
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)



test_transform = transforms.Compose(
    [
        transforms.Resize((args.resize,args.resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

In [17]:

lbl_dataset = labeled_dataset_from_path(training_df,lbl_indexes,transforms=train_transform)

val_dataset = labeled_dataset_from_path(val_df, val_indexes, transforms=test_transform)
test_dataset = labeled_dataset_from_path(test_df, test_indexes, transforms=test_transform)

In [18]:
LabeledTrainLoader = DataLoader(lbl_dataset, 
                  batch_size=args.batch_size,
                  num_workers=args.num_workers)

ValLoader = DataLoader(val_dataset, 
                  batch_size=args.batch_size,
                  num_workers=args.num_workers)

TestLoader = DataLoader(test_dataset, 
                  batch_size=args.batch_size,
                  num_workers=args.num_workers)

In [19]:
def create_densenet_model(pretrained, drop_rate, num_classes):
    backbone = densenet.densenet121(pretrained=pretrained, drop_rate=drop_rate)
    in_features = backbone.classifier.in_features 
    backbone.classifier = nn.Linear(in_features, num_classes)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
    return model

In [20]:
model1 = create_densenet_model(pretrained=True, drop_rate=args.drop_rate,num_classes=args.num_classes)
ImageNet_PreTrained_weights = copy.deepcopy(model1.state_dict())

In [25]:
def evaluate_model_performance(test_model,TestLoader,device):
    test_model.to(device)
    test_model.eval() 
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():  
        for test_inputs, test_labels, _ in tqdm(TestLoader):
            test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)

            outputs, _ = test_model(test_inputs)

            probs = torch.softmax(outputs, dim=1)  
            _, preds = torch.max(outputs, 1)  
    

            all_labels.extend(np.argmax(test_labels.cpu().numpy(), axis=1))
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    # Convert to numpy arrays for metric calculation
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)
    

    test_accuracy = accuracy_score(all_labels, all_preds)
    test_f1_score = f1_score(all_labels, all_preds, average='macro')
    test_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')

    
    

    return test_accuracy, test_f1_score, test_auc

In [None]:

exp_num = [123,234,345,456,567,678,789,890,1111,2222,3333,4444]
methods = ['confidence', 'contribution']
xx = 0
method = methods[xx]
# exp_num = [123]
num_classes = len(class_names)

colors = plt.cm.rainbow(np.linspace(0, 1, len(class_names)))

for method in methods:
    for x in exp_num:
        args.experiment_num = x
        exp_dir = os.path.join(args.save_dir,f'{args.data}_{args.arch}_{args.RL}_{args.experiment_num}')
        

        df_path = os.path.join(exp_dir,"train_df.pkl")
        training_df = pd.read_pickle(df_path)


        
        ulb_df = training_df[~training_df['contribution'].isna()].copy()
        pseudo_labels_one_hot_np = np.array(ulb_df['labels'].tolist())
        pseudo_labels_np = np.argmax(list(pseudo_labels_one_hot_np), axis=1)

        

        real_labels = np.array(ulb_df['GT'].tolist())
        real_labels_np = np.argmax(list(real_labels), axis=1)

        baseline_precision = precision_score(real_labels_np, pseudo_labels_np, average=None, labels=range(len(class_names)), zero_division=0)
        baseline_recall = recall_score(real_labels_np, pseudo_labels_np, average=None, labels=range(len(class_names)), zero_division=0)

        
        threshold_scores = np.linspace(0, 1, 21)

        # 클래스별 성능을 저장할 딕셔너리
        class_performance = {class_name: {'precisions': [], 'recalls': []} for class_name in class_names}
        os.makedirs(os.path.join(args.save_dir,'figures',f'bar_{x}'), exist_ok=True) 
        
        for i, threshold in enumerate(threshold_scores):
            ulb_score = np.array(ulb_df[method].tolist())
            mask = ulb_score >= threshold

In [26]:
metrics = ["loss","acc", "F1", "auc"]
m = 3

In [23]:
# #%%


# teacher_model_path = os.path.join(teacher_model_dir,f"best_{metrics[m]}.pth")
# check_point = torch.load(teacher_model_path)

# model1.load_state_dict(check_point['weights'])

# test_accuracy, test_f1_score, test_auc = evaluate_model_performance(model1,TestLoader,device)
# print(test_accuracy, test_f1_score, test_auc)

In [27]:
#%%
task_predictor_path = os.path.join(task_predictor_dir,f"best_{metrics[m]}.pth")
check_point = torch.load(task_predictor_path)

model1.load_state_dict(check_point['weights'])

test_accuracy, test_f1_score, test_auc = evaluate_model_performance(model1,TestLoader,device)
print(test_accuracy, test_f1_score, test_auc)

100%|██████████| 48/48 [00:02<00:00, 17.39it/s]

0.623015873015873 0.22426168294902768 0.8319069984667792





In [28]:
check_point['val_auc']

0.9226042012261549

In [25]:
# combined_array

# combined_dataset = labeled_dataset_from_path(training_df,combined_array,transforms=train_transform)

# CombinedTrainLoader = DataLoader(combined_dataset, 
#                   batch_size=args.batch_size,
#                   num_workers=args.num_workers)

In [26]:
from torch.optim.lr_scheduler import CosineAnnealingLR

FT_log_dir = os.path.join(exp_dir,"FT_log")


optimizer = torch.optim.Adam(
            model1.parameters(),
            lr=args.ft_lr,
            betas=(0.9, 0.99),
            eps=0.1,
        )
# optimizer = optim.Adam(model1.classifier.parameters(), lr=1e-4)  # 분류기만 학습

# scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5)

criterion = nn.CrossEntropyLoss()

writer = SummaryWriter(FT_log_dir)  # TensorBoard SummaryWriter
model1.to(device)
best_loss = 100
best_acc = 0
best_F1 = 0
best_auc = 0
for epoch in tqdm(range(args.num_epochs)):

    running_loss = 0.0
    model1.train() 
    for inputs, labels, idx in LabeledTrainLoader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model1(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    avg_train_loss = running_loss / len(LabeledTrainLoader)
    writer.add_scalar('Loss/Train', avg_train_loss, epoch)

    model1.eval()  # Set model to evaluation mode
    val_loss = 0.0
    all_labels = []
    all_preds = []
    all_probs = []
    
    with torch.no_grad():
        for val_inputs, val_labels, _ in ValLoader:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_outputs = model1(val_inputs)
            loss = criterion(val_outputs, val_labels)
            val_loss += loss.item()
            
            # Get the predicted labels and probabilities
            probs = torch.softmax(val_outputs, dim=1)
            _, preds = torch.max(val_outputs, 1)

                            
            # all_labels.extend(val_labels.cpu().numpy())
            all_labels.extend(np.argmax(val_labels.cpu().numpy(), axis=1))
            all_preds.extend(preds.cpu().numpy())

            all_probs.extend(probs.cpu().numpy())
    

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)


    avg_val_loss = val_loss / len(ValLoader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_f1_score = f1_score(all_labels, all_preds, average='macro')
    val_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')


    writer.add_scalar('Loss/Validation', val_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_accuracy, epoch)
    writer.add_scalar('F1-Score/Validation', val_f1_score, epoch)
    writer.add_scalar('AUC/Validation', val_auc, epoch)
    
    state_dict = {
        "epoch": epoch + 1,
        "weights": model1.state_dict(),
        "val_loss": avg_val_loss,
        "val_accuracy": val_accuracy,
        "val_f1_score": val_f1_score,
        "val_auc": val_auc,
    }



    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        # torch.save(state_dict, os.path.join(teacher_model_dir,"best_loss.pth"))
    if val_accuracy >  best_acc:
        best_acc = val_accuracy
        # torch.save(state_dict, os.path.join(teacher_model_dir,"best_acc.pth"))
    if  val_f1_score > best_F1:
        best_F1 = val_f1_score
        # torch.save(state_dict, os.path.join(teacher_model_dir,"best_F1.pth"))
    if val_auc > best_auc:
        best_auc = val_auc
        best_model_wts = copy.deepcopy(model1.state_dict())
        # torch.save(state_dict, os.path.join(teacher_model_dir,"best_auc.pth"))
# scheduler.step()

writer.close()  

100%|██████████| 128/128 [58:35<00:00, 27.47s/it]


In [27]:
best_auc


0.9408513466926806

In [30]:
test_accuracy, test_f1_score, test_auc = evaluate_model_performance(model1,ValLoader,device)
print(test_accuracy, test_f1_score, test_auc)

100%|██████████| 7/7 [00:01<00:00,  6.50it/s]

0.7927461139896373 0.46851636958041365 0.9408513466926806





In [28]:
model1.load_state_dict(best_model_wts)

test_accuracy, test_f1_score, test_auc = evaluate_model_performance(model1,TestLoader,device)
print(test_accuracy, test_f1_score, test_auc)

100%|██████████| 48/48 [00:03<00:00, 13.43it/s]

0.6574074074074074 0.3956114658758397 0.865600493059386





In [29]:
ValLoader

<torch.utils.data.dataloader.DataLoader at 0x7f2bf93e8130>