In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import random 
import torch
import sys
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
from totalsegmentator.libs import download_pretrained_weights
import nibabel as nib
from monai.networks.nets import BasicUNet
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
os.getcwd()  # display the kernel's current working directory (last expression of the cell)

In [None]:
# Reproducibility: seed every RNG used in this notebook.
def seed_everything(random_seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch; make cuDNN deterministic.

    Args:
        random_seed: integer seed shared by all generators (default 42).
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Deterministic cuDNN kernels and no autotuning, so repeated runs
    # select the same algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(42)

# Project root: case folders are read from <dir_path>/data and
# checkpoints are written to <dir_path>/weights.
dir_path = '/home/snuhub-user/workspace/updown_share/Pleural_CT/'

# Compute device: prefer the GPU when one is visible.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Training hyper-parameters
Epochs = 75               # maximum number of epochs
Learning_rate = 0.005     # initial Adam learning rate
Batch_size = 32           # training mini-batch size
Early_stopping_num = 25   # epochs without improvement before stopping

In [None]:
# Image preprocessing pipeline: ToTensor converts a numpy slice into a
# channel-first tensor. Placeholder — extend as the project evolves.
preprocessing_steps = [transforms.ToTensor()]
train_transform = transforms.Compose(preprocessing_steps)

# Augmentation pipeline. NOTE(review): not referenced by the visible cells;
# an image-only random flip would also desynchronise image and mask.
augmentation_steps = [transforms.RandomHorizontalFlip()]
train_augmentation = transforms.Compose(augmentation_steps)

In [None]:
class SegDataset(Dataset):
    """2-D slice dataset built from paired CT volumes and pleural-effusion masks.

    Each row of ``df`` names, in its first column, a case folder under
    ``<dir_path>/data/SegCT_Anony10Percent`` that holds ``original.nii.gz``
    and ``pleural_effusion.nii.gz``. Both volumes are loaded eagerly in
    ``__init__`` and cut into slices along their third axis.

    Args:
        df: DataFrame whose first column is the case folder name.
        augmentation: stored on the instance but NOT applied by
            ``__getitem__``. NOTE(review): an image-only transform here
            (e.g. a random flip) would desynchronise image and mask.
        preprocessing: optional callable applied to the image slice
            (e.g. ``transforms.ToTensor()``).
        train_mode: if True keep every other slice (indices 0, 2, 4, ...);
            otherwise keep every slice.
    """

    def __init__(self, df, augmentation=None, preprocessing=None, train_mode=True):
        self.df = df
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.train_mode = train_mode
        self.slices = []

        for i in tqdm(range(len(self.df))):
            image_data, mask_data = self._load_case(self.df.iloc[i, 0])
            for j in self._slice_indices(image_data.shape[2], train_mode):
                self.slices.append((image_data[:, :, j], mask_data[:, :, j]))

    @staticmethod
    def _load_case(data_path):
        """Load one case folder and return its (image, mask) voxel arrays."""
        file_path = os.path.join(dir_path, 'data', 'SegCT_Anony10Percent', data_path)
        image = nib.load(os.path.join(file_path, 'original.nii.gz'))
        mask = nib.load(os.path.join(file_path, 'pleural_effusion.nii.gz'))
        return image.get_fdata(), mask.get_fdata()

    @staticmethod
    def _slice_indices(depth, train_mode):
        """Slice indices to extract: every other slice when training, all otherwise."""
        # range(0, depth, 2) also keeps the last even index of an odd-depth
        # volume (e.g. slice 4 of 5), which range(depth // 2) * 2 would drop.
        return range(0, int(depth), 2 if train_mode else 1)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for slice ``index``.

        The mask gets a leading channel axis and is returned as an int64
        tensor. The image goes through ``preprocessing`` when one was given;
        otherwise it is wrapped as a ``(1, H, W)`` tensor so its layout
        matches the mask.
        """
        img, mask = self.slices[index]

        mask = np.expand_dims(mask, axis=0)  # add channel dimension

        if self.preprocessing:
            img = self.preprocessing(img)
        else:
            img = torch.from_numpy(np.expand_dims(img, axis=0))

        mask_tensor = torch.from_numpy(mask).long()
        return img, mask_tensor

    def __len__(self):
        """Number of extracted slices across all cases."""
        return len(self.slices)



In [None]:
# 2-D U-Net: one input channel (CT slice) and one output channel
# (binary effusion mask).
model = BasicUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    features=(32, 64, 128, 256, 512, 32),
)
model.to(device)

In [None]:
def dice_score(mask, target, epsilon=1e-8):
    """Mean Dice coefficient between two batched masks.

    Args:
        mask: predicted mask tensor with the batch on dim 0 and at least one
            further dim, e.g. a binarised model output of shape (B, C, H, W).
        target: ground-truth tensor with the same shape as ``mask``.
        epsilon: smoothing term added to numerator and denominator; keeps the
            ratio finite and makes a sample where both masks are empty score 1.0.

    Returns:
        0-dim tensor: per-sample Dice averaged over the batch.
    """
    # Reduce over every non-batch dim so both (B, H, W) and (B, C, H, W)
    # inputs work; for 4-D inputs this is exactly dim=(1, 2, 3).
    reduce_dims = tuple(range(1, mask.dim()))
    intersection = torch.sum(mask * target, dim=reduce_dims)
    # |A| + |B| (sum of both mask sizes), the Dice denominator.
    total = torch.sum(mask, dim=reduce_dims) + torch.sum(target, dim=reduce_dims)

    dice = (2.0 * intersection + epsilon) / (total + epsilon)

    # Average the per-sample scores into a single batch-level value.
    return dice.mean()

In [None]:
class Trainer:
    """Train a binary-segmentation model with BCE-with-logits loss, cosine LR
    decay, best-validation-loss checkpointing and early stopping.

    Hyper-parameters come from the notebook config (``Epochs``,
    ``Learning_rate``, ``Batch_size``, ``Early_stopping_num``). Batches are
    moved to the notebook-level ``device``; the model itself is not moved
    here, so place it on ``device`` beforehand.

    Per-epoch history filled by :meth:`train`:
        train_losses: sample-weighted mean training loss.
        valid_losses: sample-weighted mean validation loss.
        val_dice_scores: sample-weighted mean validation Dice (Python floats).
    """

    def __init__(self, train_dataset, valid_dataset, model):
        self.epochs = Epochs  # config-cell value rather than a second hard-coded 75
        self.learning_rate = Learning_rate
        self.batch_size = Batch_size
        self.valid_batch_size = 22
        self.early_stopping_num = Early_stopping_num

        self.model = model
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

        self.train_losses = []
        self.valid_losses = []
        self.val_dice_scores = []

    def _train_one_epoch(self, loader, criterion, optimizer):
        """Run one optimisation pass over ``loader``; return the mean loss per sample.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        self.model.train()
        running_loss = 0.0
        total_samples = 0
        for inputs, masks in tqdm(loader, mininterval=30.0):
            inputs, masks = inputs.float().to(device), masks.float().to(device)
            batch_size = inputs.size(0)

            optimizer.zero_grad()
            logits = self.model(inputs)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch does not skew the mean.
            running_loss += loss.item() * batch_size
            total_samples += batch_size
        if total_samples == 0:
            raise ValueError('training loader produced no batches')
        return running_loss / total_samples

    def _validate(self, loader, criterion):
        """Evaluate on ``loader``; return ``(mean loss, mean Dice)`` as floats.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        self.model.eval()
        loss_total = 0.0
        score_total = 0.0
        total_samples = 0
        with torch.no_grad():
            for inputs, masks in loader:
                inputs, masks = inputs.float().to(device), masks.float().to(device)
                batch_size = inputs.size(0)

                logits = self.model(inputs)
                # Binarise probabilities at 0.5 for the Dice metric.
                y_pred = (torch.sigmoid(logits) > 0.5).float()
                score_total += dice_score(y_pred, masks).item() * batch_size
                loss_total += criterion(logits, masks).item() * batch_size
                total_samples += batch_size
        if total_samples == 0:
            raise ValueError('validation loader produced no batches')
        return loss_total / total_samples, score_total / total_samples

    def train(self):
        """Fit the model, saving a checkpoint on every new best validation loss
        and stopping after ``early_stopping_num`` epochs without improvement."""
        train_loader = DataLoader(self.train_dataset, shuffle=True, batch_size=self.batch_size)
        valid_loader = DataLoader(self.valid_dataset, shuffle=False, batch_size=self.valid_batch_size)
        # BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable form,
        # so the model's raw outputs are fed to it directly.
        # Consider adding a Dice term for class imbalance.
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=25, eta_min=0.00001)

        weights_dir = os.path.join(dir_path, 'weights')
        os.makedirs(weights_dir, exist_ok=True)

        best_valid_loss = float('inf')
        early_stopping_counter = 0

        for epoch in tqdm(range(self.epochs)):
            print(f'Epoch {epoch + 1} / {self.epochs}')
            train_loss = self._train_one_epoch(train_loader, criterion, optimizer)
            self.train_losses.append(train_loss)
            print(f'Training loss: {round(train_loss, 4)}')

            val_loss, val_score = self._validate(valid_loader, criterion)
            self.valid_losses.append(val_loss)
            self.val_dice_scores.append(val_score)
            print(f'Dice Score: {val_score} | Valid Loss: {val_loss}')

            scheduler.step()

            # Lower validation loss is better: checkpoint only on a new minimum.
            if val_loss < best_valid_loss:
                best_valid_loss = val_loss
                torch.save(self.model.state_dict(),
                           os.path.join(weights_dir, f'top_second_epoch_{epoch}.pth'))
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1

            if early_stopping_counter >= self.early_stopping_num:
                print("Early stopping triggered.")
                break

    def plot(self):
        """Plot per-epoch training loss, validation loss and validation Dice."""
        epochs = range(1, len(self.train_losses) + 1)
        # float() accepts Python floats as well as 0-dim tensors on any device.
        val_dice_scores = [float(score) for score in self.val_dice_scores]
        fig, ax = plt.subplots()
        ax.plot(epochs, self.train_losses, label='Training Loss')
        ax.plot(epochs, self.valid_losses, label='Validation Loss')
        ax.plot(epochs, val_dice_scores, label='Validation Dice Score')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss/Dice Score')
        ax.set_title('Loss/Score Visualization')
        ax.legend()
        ax.grid(True)
        plt.show()


In [None]:
from sklearn.model_selection import train_test_split

# Case list; keep only rows flagged Exist == 1.
df_raw = pd.read_excel(os.path.join(dir_path, 'PF_CT_10percent.xlsx'))
df_exist = df_raw[df_raw.Exist == 1].reset_index(drop=True)

# Case-level split stratified on Bot_miss: an integer test_size is an absolute
# count, so 25 cases are held out, of which 3 form the test set and the
# rest the validation set.
df_train, df_validtest = train_test_split(df_exist, test_size=25, stratify=df_exist.Bot_miss, random_state=42)
df_valid, df_test = train_test_split(df_validtest, test_size=3, stratify=df_validtest.Bot_miss, random_state=42)

# Sanity checks: the splits are disjoint and together cover df_exist.
assert len(df_train) + len(df_valid) + len(df_test) == len(df_exist)
assert df_train.index.intersection(df_validtest.index).empty
assert df_valid.index.intersection(df_test.index).empty

for split in (df_train, df_validtest, df_valid, df_test):
    print(split.Bot_miss.value_counts())

In [None]:
# The transform must be passed as `preprocessing`: positionally it binds to
# `augmentation`, which SegDataset never applies, so ToTensor would not run.
train_dataset = SegDataset(df_train, preprocessing=train_transform)

In [None]:
# The transform must be passed as `preprocessing`: positionally it binds to
# `augmentation`, which SegDataset never applies, so ToTensor would not run.
# NOTE(review): train_mode defaults to True, so validation also keeps only
# every other slice; pass train_mode=False to evaluate on all slices if intended.
valid_dataset = SegDataset(df_valid, preprocessing=train_transform)

In [None]:
# Bundle the datasets and the U-Net into a trainer.
trainer = Trainer(
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    model=model,
)

In [None]:
trainer.train()  # fit with checkpointing and early stopping
trainer.plot()   # chart per-epoch losses and validation Dice