# 0. Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.models import ResNet50_Weights
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix, classification_report, f1_score, recall_score,
    precision_score, roc_auc_score, average_precision_score, accuracy_score
)

from google.colab import drive
import warnings
import time
import random
from tqdm import tqdm

from torch.cuda.amp import GradScaler, autocast

# NOTE(review): this silences *every* warning for the rest of the session
# (including deprecation/numerical warnings from torch and sklearn); consider
# filtering specific categories instead so genuine problems stay visible.
warnings.filterwarnings("ignore")

In [None]:
# Mount Google Drive (Colab only; `drive` comes from google.colab) so that the
# data/checkpoint paths defined under /content/drive below are accessible.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Pick the compute device: the GPU when CUDA is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Let cuDNN autotune convolution algorithms (pays off when input shapes stay constant).
    torch.backends.cudnn.benchmark = True

# 1. Data & Model Paths

In [None]:
# Project root on the mounted Google Drive.
BASE_DIR = Path("/content/drive/My Drive/졸업프로젝트/TestDataset/")

# Checkpoints: the fine-tuned image backbone that is read, and the fused classifier that is written.
IMG_MODEL_PATH = BASE_DIR / "MMOTU/resnet50_binary_mmotu_aug.pth"
MULTIMODAL_MODEL_PATH = BASE_DIR / "multimodal_classifier_latefusion.pth"

# Pre-processed clinical features, already split into train/test archives.
PROCESSED_CLIN_DIR = BASE_DIR / "Processed_Clinical_Data"
DL_TRAIN_NPZ, DL_TEST_NPZ = (
    PROCESSED_CLIN_DIR / archive for archive in ('dl_train_data.npz', 'dl_test_data.npz')
)

# Pre-processed ultrasound images. NOTE: the held-out evaluation set is the
# *validation* archive of this directory.
PROCESSED_IMG_DIR = BASE_DIR / "MMOTU/ResNet50_ViT_Processed_Data"
TRAIN_IMG_NPZ, TEST_IMG_NPZ = (
    PROCESSED_IMG_DIR / "train_augmented2_data.npz",
    PROCESSED_IMG_DIR / "validation_data.npz",
)

# 2. Data Preprocessing

In [None]:
# MMOTU 8-class label id -> binary target (0 = benign, 1 = malignant):
#   0 Chocolate cyst, 1 Serous cystadenoma, 2 Teratoma, 3 Theca cell tumor,
#   4 Simple cyst, 5 Normal ovary, 6 Mucinous cystadenoma (benign/borderline)
#   -> all 0;  7 High grade serous cystadenocarcinoma -> 1.
# NOTE(review): not referenced by the code in this notebook section.
BINARY_MAPPING = {label: int(label == 7) for label in range(8)}

In [None]:
class MMOTUUnpairedDataset(Dataset):
    """Pair ultrasound images with randomly drawn clinical records for co-training.

    The image archive (an .npz with ``images`` and ``labels`` arrays) and the
    clinical arrays are loaded independently; there is no patient-level link
    between them. Item ``idx`` pairs image ``idx % n_images`` with a clinical
    row chosen uniformly at random, and the dataset length is the larger of
    the two sample counts.

    Args:
        img_npz_path: Path to an .npz file containing ``images`` and ``labels``.
        X_clin_np: Clinical feature array, one row per record.
        Y_clin_np: Clinical labels, one entry per row of ``X_clin_np``.
        device: Unused; kept for backward compatibility (tensors stay on CPU).
        seed: Optional. When given, the clinical row paired with each index is
            fixed at construction time, making the pairing reproducible (e.g.
            for evaluation). When ``None`` (default) a fresh random row is drawn
            on every access, as before.

    Raises:
        ValueError: if either modality is empty or a feature/label array pair
            has mismatched lengths.
    """

    def __init__(self, img_npz_path, X_clin_np, Y_clin_np, device='cpu', seed=None):
        # NOTE(review): allow_pickle=True lets np.load unpickle objects; only use
        # archives from a trusted source.
        # The context manager closes the archive once the arrays are in memory.
        with np.load(img_npz_path, allow_pickle=True) as img_data:
            self.img_images = torch.from_numpy(img_data['images']).float()
            self.img_labels = torch.from_numpy(img_data['labels'].astype(np.float32)).float()

        self.clin_data = torch.from_numpy(X_clin_np).float()
        self.clin_labels = torch.from_numpy(Y_clin_np).float()

        self.img_len = len(self.img_images)
        self.clin_len = len(self.clin_data)
        # Fail here with a clear message rather than inside __getitem__
        # (modulo by zero / randint(0, -1)).
        if self.img_len == 0 or self.clin_len == 0:
            raise ValueError(
                f"Both modalities must be non-empty (images={self.img_len}, clinical={self.clin_len})"
            )
        if len(self.img_labels) != self.img_len:
            raise ValueError("Image 'images' and 'labels' arrays have different lengths")
        if len(self.clin_labels) != self.clin_len:
            raise ValueError("Clinical features and labels have different lengths")
        self.max_len = max(self.img_len, self.clin_len)

        # Optional fixed pairing: one clinical index per dataset index.
        self.seed = seed
        self._fixed_clin_idx = None
        if seed is not None:
            rng = np.random.default_rng(seed)
            self._fixed_clin_idx = rng.integers(0, self.clin_len, size=self.max_len)

        print(f"Dataset Initialized. Image Samples: {self.img_len}, Clinical Samples: {self.clin_len}")

    def __len__(self):
        return self.max_len

    def __getitem__(self, idx):
        """Return ``(image, clinical_features, image_label, clinical_label)``."""
        img_idx = idx % self.img_len
        if self._fixed_clin_idx is None:
            clin_idx = random.randint(0, self.clin_len - 1)
        else:
            clin_idx = int(self._fixed_clin_idx[idx])

        return (
            self.img_images[img_idx],
            self.clin_data[clin_idx],
            self.img_labels[img_idx],
            self.clin_labels[clin_idx]
        )

# 3. Encoders & Late-Fusion Model

In [None]:
class ImageEncoder(nn.Module):
    """ResNet-50 image feature extractor returning one flat vector per image.

    An ImageNet-initialised ResNet-50 is truncated after ``layer4`` plus global
    average pooling (``self.backbone``; ``output_dim`` = ``fc.in_features``).
    Tensors from ``model_path`` are copied in where names and shapes match.
    Only ``layer3`` and ``layer4`` are left trainable.

    Args:
        model_path: Checkpoint path (a ``pathlib.Path`` in this notebook). A raw
            state dict, a dict wrapping one under ``state_dict`` /
            ``model_state_dict``, or a pickled module are accepted.
        device: Unused; kept for backward compatibility. Move the encoder
            with ``.to(device)`` instead.
    """

    def __init__(self, model_path, device):
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # The Sequential shares its submodules with `resnet`, so weights loaded
        # into `resnet` are visible through `self.backbone`. The Sequential's
        # state_dict keys ("0.weight", ..., "7.*") are kept unchanged so saved
        # multimodal checkpoints stay loadable.
        self.backbone = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.output_dim = resnet.fc.in_features  # 2048

        self._load_pretrained(resnet, model_path)

        # Fine-tuning: freeze the whole backbone, then unfreeze layer3/layer4 by
        # module reference. Inside nn.Sequential the children are named "0".."8"
        # (layer3 -> "6.*", layer4 -> "7.*"), so checking parameter *names* for
        # "layer3"/"layer4" never matches and would leave everything frozen.
        for param in self.backbone.parameters():
            param.requires_grad = False
        for stage in (resnet.layer3, resnet.layer4):
            for param in stage.parameters():
                param.requires_grad = True
        # NOTE(review): BatchNorm layers in the frozen stages still update their
        # running statistics whenever the model is in train() mode.

    def _load_pretrained(self, resnet, model_path):
        """Best-effort copy of checkpoint tensors into ``resnet`` / ``self.backbone``.

        Keys are matched by name *and* shape:
        - first against the torchvision ResNet-50 namespace (``conv1.weight``,
          ``layer3.0...``), skipping the unused ``fc.*`` head;
        - then against the backbone's own Sequential namespace (``0.weight``, ``6.0...``).

        A leading ``module.`` prefix (DataParallel) is stripped. Any failure is
        reported and the ImageNet weights are kept.
        """
        name = getattr(model_path, 'name', model_path)
        try:
            checkpoint = torch.load(model_path, map_location='cpu')
        except Exception as e:  # best effort: fall back to ImageNet weights
            print(f"[WARN]: Failed to load pretrained Image Encoder. Using ImageNet weights. Error: {e}")
            return

        if isinstance(checkpoint, nn.Module):
            checkpoint = checkpoint.state_dict()
        if isinstance(checkpoint, dict):
            for wrapper_key in ('state_dict', 'model_state_dict'):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    checkpoint = checkpoint[wrapper_key]
                    break
        if not isinstance(checkpoint, dict):
            print(f"[WARN]: Unrecognised checkpoint format in {name}. Using ImageNet weights.")
            return

        state_dict = {}
        for key, value in checkpoint.items():
            if isinstance(key, str) and key.startswith('module.'):
                key = key[len('module.'):]
            state_dict[key] = value

        loaded = 0
        for target, skip_prefix in ((resnet, 'fc.'), (self.backbone, None)):
            target_sd = target.state_dict()
            matched = {
                k: v for k, v in state_dict.items()
                if k in target_sd
                and (skip_prefix is None or not k.startswith(skip_prefix))
                and getattr(v, 'shape', None) == target_sd[k].shape
            }
            if matched:
                target.load_state_dict(matched, strict=False)
                loaded += len(matched)

        if loaded == 0:
            print(f"[WARN]: No tensors in {name} matched the ResNet-50 backbone. Using ImageNet weights.")
        else:
            print(f"Image Encoder loaded successfully from {name} ({loaded} tensors)")

    def forward(self, x):
        """Return flattened pooled features of shape ``(batch, output_dim)``."""
        return torch.flatten(self.backbone(x), 1)

In [None]:
class MultiModalClassifier_LateFusion(nn.Module):
    """Three-head classifier over image and clinical features.

    It produces three raw logits per sample:
    - one from the image features alone;
    - one from the encoded clinical features alone;
    - one from an MLP over the concatenation of both feature vectors.

    Args:
        img_encoder: Module returning 2-D ``(batch, output_dim)`` features;
            must expose an ``output_dim`` attribute.
        clinical_dim: Number of clinical input features.
        dropout_rate: Dropout probability in the clinical encoder and the
            fusion projection.
        clinical_hidden_dim: Hidden width of the clinical MLP (default 256).
        clinical_feature_dim: Output width of the clinical MLP (default 128).
        fusion_hidden_dim: Width of the fusion projection (default 512).

    The defaults reproduce the original architecture (identical state_dict keys and shapes).
    """

    def __init__(self, img_encoder, clinical_dim, dropout_rate=0.5,
                 clinical_hidden_dim=256, clinical_feature_dim=128, fusion_hidden_dim=512):
        super().__init__()

        # Image encoder
        self.image_encoder = img_encoder
        self.image_feature_dim = img_encoder.output_dim

        # Clinical encoder: clinical_dim -> hidden -> feature
        self.clinical_feature_dim = clinical_feature_dim
        self.clinical_encoder = nn.Sequential(
            nn.Linear(clinical_dim, clinical_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(clinical_hidden_dim, clinical_feature_dim),
            nn.ReLU()
        )

        # Modality-specific heads
        self.image_head = nn.Linear(self.image_feature_dim, 1)
        self.clinical_head = nn.Linear(self.clinical_feature_dim, 1)

        # Fusion head over the concatenated [image, clinical] features
        self.fusion_proj = nn.Sequential(
            nn.Linear(self.image_feature_dim + self.clinical_feature_dim, fusion_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        self.fusion_head = nn.Linear(fusion_hidden_dim, 1)

    def forward(self, x_img, x_clin):
        """Return ``(img_out, clin_out, fusion_out)`` logits, each ``(batch, 1)``."""
        img_feat = self.image_encoder(x_img)
        clin_feat = self.clinical_encoder(x_clin)

        img_out = self.image_head(img_feat)
        clin_out = self.clinical_head(clin_feat)

        fused = torch.cat([img_feat, clin_feat], dim=1)
        fusion_out = self.fusion_head(self.fusion_proj(fused))

        return img_out, clin_out, fusion_out

# 4. Evaluation

In [None]:
def calculate_metrics(y_true, y_pred, y_prob, threshold):
    """Compute binary-classification metrics at a probability threshold.

    Args:
        y_true: Ground-truth labels (0/1), array-like.
        y_pred: Ignored; kept for backward compatibility. Hard predictions are
            always re-derived as ``y_prob >= threshold`` so every metric uses
            the same decision rule.
        y_prob: Predicted probability of the positive class (1), array-like.
        threshold: Decision threshold applied to ``y_prob``.

    Returns:
        dict with accuracy, recall, precision, F1, ROC-AUC and PR-AUC (NaN when
        undefined, e.g. only one class present), specificity, and the
        confusion-matrix counts ``(tn, fp, fn, tp)`` as plain Python ints.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred_thresh = (y_prob >= threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent; convert
    # np.int64 counts to int so they print and serialize cleanly.
    tn, fp, fn, tp = (
        int(c) for c in confusion_matrix(y_true, y_pred_thresh, labels=[0, 1]).ravel()
    )

    accuracy = accuracy_score(y_true, y_pred_thresh)
    recall = recall_score(y_true, y_pred_thresh, zero_division=0)  # sensitivity
    precision = precision_score(y_true, y_pred_thresh, zero_division=0)
    f1 = f1_score(y_true, y_pred_thresh, zero_division=0)

    # Ranking metrics use the raw probabilities; both are guarded the same way.
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc = np.nan
    try:
        pr_auc = average_precision_score(y_true, y_prob)
    except ValueError:
        pr_auc = np.nan

    # Specificity (TNR) = TN / (TN + FP)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return {
        "Acc": accuracy,
        "Recall (Sensitivity)": recall,
        "Precision": precision,
        "F1 Score": f1,
        "ROC-AUC": auc,
        "PR-AUC": pr_auc,
        "Specificity (TNR)": specificity,
        "Confusion Matrix": (tn, fp, fn, tp)
    }

@torch.no_grad()
def evaluate_model_full(model, loader, device, threshold=0.5):
    """Evaluate the fusion head over ``loader``, print a report and return the metrics.

    Each batch is ``(x_img, x_clin, y_img, y_clin)``. The image label
    ``y_img`` is the ground truth and ``y_clin`` is ignored. Probabilities are
    ``sigmoid(fusion_out)``.

    NOTE(review): with ``MMOTUUnpairedDataset`` each image can appear many times
    paired with different random clinical rows. The counts are therefore per
    (image, clinical-row) pair, and results vary between calls unless the RNGs
    are reseeded.

    Args:
        model: Module returning ``(img_out, clin_out, fusion_out)`` logits.
        loader: Iterable of evaluation batches.
        device: Device the inputs are moved to.
        threshold: Decision threshold on the fusion probability.

    Returns:
        The metrics dict produced by ``calculate_metrics``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    y_true, y_prob = [], []

    for x_img, x_clin, y_img, _y_clin in loader:
        x_img, x_clin = x_img.to(device), x_clin.to(device)

        _, _, fusion_out = model(x_img, x_clin)
        probs = torch.sigmoid(fusion_out)

        y_true.extend(y_img.cpu().numpy().ravel())
        y_prob.extend(probs.cpu().numpy().ravel())

    if not y_true:
        raise ValueError("evaluate_model_full: loader produced no samples")

    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    y_pred = (y_prob >= threshold).astype(int)  # hard predictions at the threshold

    metrics = calculate_metrics(y_true, y_pred, y_prob, threshold)

    print(f"\n--- Evaluation at Threshold {threshold} ---")
    print(f"Accuracy: {metrics['Acc']:.4f} | F1 Score: {metrics['F1 Score']:.4f} | ROC-AUC: {metrics['ROC-AUC']:.4f}")
    print(f"Recall: {metrics['Recall (Sensitivity)']:.4f} | Precision: {metrics['Precision']:.4f} | Specificity (TNR): {metrics['Specificity (TNR)']:.4f}")
    print(f"PR-AUC: {metrics['PR-AUC']:.4f}")
    print("\nConfusion Matrix (TN, FP, FN, TP):", metrics['Confusion Matrix'])
    # Per-class precision/recall/F1 breakdown
    print("\nClassification report:")
    print(classification_report(y_true, y_pred, target_names=['Benign (0)', 'Malignant (1)'], digits=4, zero_division=0))

    return metrics

# 5. Main execution

In [None]:
print("Loading pre-processed data from NPZ files...")

# Clinical DL data: pre-split train/test feature matrices and labels.
try:
    # Context managers close the .npz archives once the arrays are read into memory.
    with np.load(DL_TRAIN_NPZ) as clin_train_data, np.load(DL_TEST_NPZ) as clin_test_data:
        X_clin_train = clin_train_data['X_train']
        Y_clin_train = clin_train_data['Y_train']
        X_clin_test = clin_test_data['X_test']
        Y_clin_test = clin_test_data['Y_test']
except FileNotFoundError as err:
    # Raise instead of calling exit(): inside a notebook kernel exit() may restart
    # the kernel or not stop later code, whereas an exception halts this cell
    # (and "Run all") with a clear message.
    raise FileNotFoundError(
        f"[ERROR] Clinical NPZ file not found: {err.filename}. "
        "Run clinical_data_prep_split_fixed.py first!"
    ) from err

CLINICAL_DATA_DIM = X_clin_train.shape[1]
print(f"Clinical DL data loaded. Feature Dim: {CLINICAL_DATA_DIM}")

# Ultrasound images are loaded inside the dataset; each image is paired with a
# randomly drawn clinical row.
train_dataset = MMOTUUnpairedDataset(TRAIN_IMG_NPZ, X_clin_train, Y_clin_train, device=device)
test_dataset = MMOTUUnpairedDataset(TEST_IMG_NPZ, X_clin_test, Y_clin_test, device=device)

BATCH_SIZE = 64
NUM_WORKERS = 4
# Pinned host memory only helps (and is only supported without warnings) with a CUDA device.
PIN_MEMORY = device.type == "cuda"
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

Loading pre-processed data from NPZ files...
Clinical DL data loaded. Feature Dim: 37
Dataset Initialized. Image Samples: 1912, Clinical Samples: 160080
Dataset Initialized. Image Samples: 469, Clinical Samples: 40020


In [None]:
image_encoder = ImageEncoder(IMG_MODEL_PATH, device).to(device)

multi_model = MultiModalClassifier_LateFusion(
    img_encoder=image_encoder,
    clinical_dim=CLINICAL_DATA_DIM
).to(device)

# Loss weights & optimizer.
# pos_weight = (#negatives / #positives) * POS_WEIGHT_BOOST; the extra boost
# up-weights the malignant class beyond plain class balancing.
POS_WEIGHT_BOOST = 2.5


def _imbalance_pos_weight(labels, boost=POS_WEIGHT_BOOST):
    """Return ``(neg, pos, weight)`` for a tensor of 0/1 labels."""
    neg = (labels == 0).sum().item()
    pos = (labels == 1).sum().item()
    if pos == 0:
        raise ValueError("No positive (label 1) samples found; cannot compute pos_weight.")
    return neg, pos, (neg / pos) * boost


# The image and fusion heads are trained against the image labels (y_img).
neg, pos, pos_weight_value = _imbalance_pos_weight(train_dataset.img_labels)
pos_weight = torch.tensor([pos_weight_value], device=device)

# The clinical head is trained against the clinical labels (y_clin), whose
# class balance may differ from the image set, so derive its weight from them.
_, _, pos_weight_clin_value = _imbalance_pos_weight(train_dataset.clin_labels)
pos_weight_clin = torch.tensor([pos_weight_clin_value], device=device)

criterion_img = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion_clin = nn.BCEWithLogitsLoss(pos_weight=pos_weight_clin)
criterion_fusion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

λ_img, λ_clin, λ_fusion = 1.0, 1.0, 1.0

# Only parameters left trainable by the encoders are optimized.
# NOTE(review): Adam's weight_decay is coupled L2; AdamW gives decoupled decay.
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, multi_model.parameters()),
    lr=5e-5, weight_decay=0.01
)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 178MB/s]


Image Encoder loaded successfully from resnet50_binary_mmotu_aug.pth


In [None]:
NUM_EPOCHS = 5
print(f"\nStarting MultiModal Binary Training for {NUM_EPOCHS} epochs ...")
print(f"Total Trainable Params: {sum(p.numel() for p in multi_model.parameters() if p.requires_grad) / 1e6:.2f} Million")

# torch.cuda.amp autocast/GradScaler are CUDA-specific: enable mixed precision
# only on a CUDA device, otherwise both become no-ops and training runs in fp32.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

for epoch in range(1, NUM_EPOCHS + 1):
    multi_model.train()
    total_loss = 0.0
    start_time = time.time()

    for x_img, x_clin, y_img, y_clin in tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}"):
        x_img = x_img.to(device, non_blocking=True)
        x_clin = x_clin.to(device, non_blocking=True)
        y_img = y_img.unsqueeze(1).to(device, non_blocking=True)
        y_clin = y_clin.unsqueeze(1).to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            img_out, clin_out, fusion_out = multi_model(x_img, x_clin)

            loss_img = criterion_img(img_out, y_img)
            loss_clin = criterion_clin(clin_out, y_clin)
            # NOTE(review): the clinical row in each pair is drawn at random
            # (unpaired data), so the fusion target is the image label and the
            # clinical half of the fused input is not patient-matched.
            loss_fusion = criterion_fusion(fusion_out, y_img)

            loss = λ_img * loss_img + λ_clin * loss_clin + λ_fusion * loss_fusion

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    # Guard against an empty loader (drop_last with fewer samples than one batch).
    avg_loss = total_loss / max(len(train_loader), 1)
    end_time = time.time()

    print(f"[Epoch {epoch:02d}] Time: {end_time - start_time:.1f}s | Avg Loss: {avg_loss:.4f}")

print(f"\nSaving final model state dict to {MULTIMODAL_MODEL_PATH}")
torch.save(multi_model.state_dict(), MULTIMODAL_MODEL_PATH)
print("Final multimodal classifier saved successfully.")

print("\n=== FINAL MULTIMODAL EVALUATION (Threshold Tuning) ===")
# Reseed before every pass so each threshold is scored on the same random
# clinical pairings. DataLoader workers seed their `random` module from a base
# seed drawn from torch's global RNG; with num_workers=0 the main-process
# `random` module is used directly.
# NOTE(review): choosing a threshold from test-set results leaks test
# information; tune it on a separate validation split.
EVAL_SEED = 42
for eval_threshold in (0.5, 0.4, 0.3, 0.2, 0.1):
    random.seed(EVAL_SEED)
    torch.manual_seed(EVAL_SEED)
    evaluate_model_full(multi_model, test_loader, device, threshold=eval_threshold)


Starting MultiModal Binary Training for 5 epochs ...
Total Trainable Params: 1.16 Million


Epoch 1/5: 100%|██████████| 2501/2501 [03:54<00:00, 10.65it/s]


[Epoch 01] Time: 234.8s | Avg Loss: 1.8194


Epoch 2/5: 100%|██████████| 2501/2501 [03:49<00:00, 10.89it/s]


[Epoch 02] Time: 229.7s | Avg Loss: 1.4031


Epoch 3/5: 100%|██████████| 2501/2501 [03:50<00:00, 10.86it/s]


[Epoch 03] Time: 230.4s | Avg Loss: 1.2921


Epoch 4/5: 100%|██████████| 2501/2501 [03:50<00:00, 10.86it/s]


[Epoch 04] Time: 230.4s | Avg Loss: 1.2350


Epoch 5/5: 100%|██████████| 2501/2501 [03:50<00:00, 10.86it/s]


[Epoch 05] Time: 230.4s | Avg Loss: 1.1891

Saving final model state dict to /content/drive/My Drive/졸업프로젝트/TestDataset/multimodal_classifier_latefusion.pth
Final multimodal classifier saved successfully.

=== FINAL MULTIMODAL EVALUATION (Threshold Tuning) ===

--- Evaluation at Threshold 0.5 ---
Accuracy: 0.9531 | F1 Score: 0.3538 | ROC-AUC: 0.7618
Recall: 0.4009 | Precision: 0.3165 | Specificity (TNR): 0.9713
PR-AUC: 0.1855

Confusion Matrix (TN, FP, FN, TP): (np.int64(37628), np.int64(1110), np.int64(768), np.int64(514))

Classification report:
               precision    recall  f1-score   support

   Benign (0)     0.9800    0.9713    0.9757     38738
Malignant (1)     0.3165    0.4009    0.3538      1282

     accuracy                         0.9531     40020
    macro avg     0.6483    0.6861    0.6647     40020
 weighted avg     0.9587    0.9531    0.9557     40020

