In [3]:
# ============================================================
# Mean Teacher (EMA) with CNN (ResNet-18) for MCT Spectrograms
# - 5-fold Stratified CV
# - 30% labeled / 70% unlabeled per fold (for SSL)
# - Spectrogram-safe augs (no flips); ImageNet normalization for CNN
# - EMA teacher, consistency loss (MSE), sigmoid ramp-up
# - Accuracy & Macro-F1 per fold + mean ± std
# ============================================================

import os, random, copy, math, time
from typing import List, Tuple
from collections import defaultdict

import numpy as np
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import models, transforms as T

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm import tqdm

# --------------------------
# Reproducibility & Device
# --------------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# ============================================================
# Data prep (OPTIONAL: only used if these are not defined yet)
# ============================================================
try:
    all_image_paths
    all_labels
    class_to_idx
    print("Found existing all_image_paths / all_labels / class_to_idx")
except NameError:
    # <<< Change this to your dataset root if needed >>>
    data_dir = r"E:\1 Paper MCT\Cutting Tool Paper\Dataset\cutting tool data\test_data_40_images"
    class_names = sorted([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir,d))])
    class_to_idx = {c:i for i,c in enumerate(class_names)}
    all_image_paths, all_labels = [], []
    for c in class_names:
        cdir = os.path.join(data_dir, c)
        for fname in os.listdir(cdir):
            if fname.lower().endswith((".png",".jpg",".jpeg",".bmp",".tif",".tiff")):
                all_image_paths.append(os.path.join(cdir, fname))
                all_labels.append(class_to_idx[c])
    print(f"Scanned {len(all_image_paths)} images across {len(class_to_idx)} classes.")

NUM_CLASSES = len(class_to_idx)
CLASS_NAMES = [k for k,_ in sorted(class_to_idx.items(), key=lambda x:x[1])]

# ============================================================
# Datasets
# ============================================================
class SpectrogramDataset(Dataset):
    def __init__(self, image_paths: List[str], labels: List[int], transform=None, labeled=True):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.labeled = labeled
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform: img = self.transform(img)
        if self.labeled:
            return img, torch.tensor(self.labels[idx], dtype=torch.long)
        else:
            return img

# Unlabeled dataset that returns (weak, strong) augmented views
class UnlabeledPairDataset(Dataset):
    def __init__(self, image_paths: List[str], weak_tfm, strong_tfm):
        self.image_paths = image_paths
        self.weak_tfm = weak_tfm
        self.strong_tfm = strong_tfm
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        xw = self.weak_tfm(img)
        xs = self.strong_tfm(img)
        return xw, xs

# ============================================================
# Spectrogram-safe augmentations (no flips/time reversal)
# ============================================================
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Simple SpecAugment-style time/freq masking on tensor
class RandomTimeMask:
    def __init__(self, max_width=0.2, p=0.8): self.max_width, self.p = max_width, p
    def __call__(self, x):  # x: Tensor [C,H,W], W=time
        if random.random() > self.p: return x
        _, H, W = x.shape
        w = int(W * random.uniform(0.05, self.max_width))
        if w <= 0: return x
        t0 = random.randint(0, max(0, W-w))
        x[:, :, t0:t0+w] = 0
        return x

class RandomFreqMask:
    def __init__(self, max_height=0.2, p=0.8): self.max_height, self.p = max_height, p
    def __call__(self, x):  # H=freq
        if random.random() > self.p: return x
        _, H, W = x.shape
        h = int(H * random.uniform(0.05, self.max_height))
        if h <= 0: return x
        f0 = random.randint(0, max(0, H-h))
        x[:, f0:f0+h, :] = 0
        return x

class AddGaussianNoise:
    def __init__(self, std=0.02, p=0.5): self.std, self.p = std, p
    def __call__(self, x):
        if random.random() > self.p: return x
        return x + torch.randn_like(x) * self.std

IMG_SIZE = 224

# Labeled/validation transforms
weak_transform_resnet = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_transform_resnet = weak_transform_resnet  # deterministic, same norm

# Unlabeled strong transform adds SpecAugment-like masks + slight noise
strong_transform_resnet = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    RandomTimeMask(0.20, 0.8),
    RandomFreqMask(0.20, 0.8),
    AddGaussianNoise(0.02, 0.5),
])

# ============================================================
# Model: ResNet-18 (ImageNet pretrained), classifier for NUM_CLASSES
# ============================================================
def make_resnet18(num_classes: int, pretrained: bool = True):
    # Compatible with new & old torchvision APIs
    try:
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
    except AttributeError:
        model = models.resnet18(pretrained=pretrained)
    in_f = model.fc.in_features
    model.fc = nn.Linear(in_f, num_classes)
    return model.to(device)

def logits_from_model(model, x):
    out = model(x)  # torchvision resnet returns logits tensor
    return out

# ============================================================
# Mean Teacher utils: EMA & ramp-up
# ============================================================
def update_ema(student: nn.Module, teacher: nn.Module, alpha: float):
    with torch.no_grad():
        for p_t, p_s in zip(teacher.parameters(), student.parameters()):
            p_t.data.mul_(alpha).add_(p_s.data, alpha=1.0 - alpha)

def sigmoid_rampup(cur_epoch: int, rampup_epochs: int):
    if rampup_epochs == 0: return 1.0
    t = float(np.clip(cur_epoch, 0.0, rampup_epochs))
    return float(math.exp(-5.0 * (1.0 - t / rampup_epochs) ** 2))

# ============================================================
# Training: Mean Teacher (student learns; teacher = EMA of student)
# ============================================================
def train_mean_teacher_resnet(
    student: nn.Module,
    teacher: nn.Module,
    lab_loader: DataLoader,
    unl_loader: DataLoader,
    epochs: int = 60,
    base_lr: float = 0.05,
    weight_decay: float = 5e-4,
    ema_alpha: float = 0.995,
    rampup_epochs: int = 10,
):
    ce = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(student.parameters(), lr=base_lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    mse = nn.MSELoss()

    # Start teacher = student
    teacher.load_state_dict(copy.deepcopy(student.state_dict()))
    teacher.eval()

    unl_iter = iter(unl_loader)
    for ep in range(1, epochs + 1):
        student.train()
        total, correct = 0, 0

        lam = sigmoid_rampup(ep, rampup_epochs)  # 0->1
        for xb, yb in lab_loader:
            xb, yb = xb.to(device), yb.to(device)

            # fetch an unlabeled batch (weak, strong)
            try:
                xw_u, xs_u = next(unl_iter)
            except StopIteration:
                unl_iter = iter(unl_loader)
                xw_u, xs_u = next(unl_iter)
            xw_u, xs_u = xw_u.to(device), xs_u.to(device)

            # supervised loss on labeled minibatch
            opt.zero_grad()
            logits_sup = logits_from_model(student, xb)
            loss_sup = ce(logits_sup, yb)

            # consistency loss: teacher on weak, student on strong
            with torch.no_grad():
                logits_tw = logits_from_model(teacher, xw_u)
                pw = torch.softmax(logits_tw, dim=1)
            logits_ss = logits_from_model(student, xs_u)
            ps = torch.softmax(logits_ss, dim=1)
            loss_cons = mse(ps, pw)

            loss = loss_sup + lam * loss_cons
            loss.backward()
            opt.step()

            # EMA update after each step
            update_ema(student, teacher, alpha=ema_alpha)

            total += yb.size(0)
            correct += (logits_sup.argmax(1) == yb).sum().item()

        scheduler.step()
        print(f"[MT-ResNet] epoch {ep:03d}/{epochs} | sup_acc={correct/total:.4f} | lambda={lam:.2f} | lr={scheduler.get_last_lr()[0]:.5f}")

# ============================================================
# Evaluation helper
# ============================================================
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    model.eval()
    y_true, y_pred = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = logits_from_model(model, xb)
        y_true.extend(yb.numpy())
        y_pred.extend(logits.argmax(1).cpu().numpy())
    y_true = np.array(y_true); y_pred = np.array(y_pred)
    acc = accuracy_score(y_true, y_pred)
    f1  = f1_score(y_true, y_pred, average='macro')
    return acc, f1

# ============================================================
# Cross-validation driver (5-fold, 30% labeled / 70% unlabeled inside each fold)
# ============================================================
def run_cv_mean_teacher_resnet(
    all_paths: List[str], all_lbls: List[int],
    k_folds: int = 5,
    labeled_ratio: float = 0.30,   # 30% labeled
    batch_l: int = 16,
    mu: int = 3,                   # unlabeled:labeled batch multiplier
    epochs: int = 60,
):
    indices = np.arange(len(all_paths))
    labels_np = np.array(all_lbls)
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=SEED)

    fold_acc, fold_f1 = [], []

    for fold, (train_idx, val_idx) in enumerate(skf.split(indices, labels_np), 1):
        print(f"\n========== FOLD {fold}/{k_folds} ==========")

        # Split into train/val file lists
        tr_paths = [all_paths[i] for i in train_idx]
        tr_lbls  = [all_lbls[i]  for i in train_idx]
        va_paths = [all_paths[i] for i in val_idx]
        va_lbls  = [all_lbls[i]  for i in val_idx]

        # Inside the training portion: stratified split into L and U
        L_paths, U_paths, L_lbls, _ = train_test_split(
            tr_paths, tr_lbls, test_size=1.0 - labeled_ratio, stratify=tr_lbls, random_state=SEED
        )

        print("Train/Labeled:", len(L_paths), "Train/Unlabeled:", len(U_paths), "Val:", len(va_paths))

        # Datasets & loaders
        lab_ds = SpectrogramDataset(L_paths, L_lbls, transform=weak_transform_resnet, labeled=True)
        unl_ds = UnlabeledPairDataset(U_paths, weak_tfm=weak_transform_resnet, strong_tfm=strong_transform_resnet)
        val_ds = SpectrogramDataset(va_paths, va_lbls, transform=val_transform_resnet, labeled=True)

        lab_loader = DataLoader(lab_ds, batch_size=batch_l, shuffle=True,  num_workers=0)
        unl_loader = DataLoader(unl_ds, batch_size=batch_l * mu, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_ds, batch_size=batch_l, shuffle=False, num_workers=0)

        # Models
        student = make_resnet18(NUM_CLASSES, pretrained=True)
        teacher = make_resnet18(NUM_CLASSES, pretrained=True)  # will be overwritten by EMA init

        # Train MT
        train_mean_teacher_resnet(
            student, teacher, lab_loader, unl_loader,
            epochs=epochs, base_lr=0.05, weight_decay=5e-4,
            ema_alpha=0.995, rampup_epochs=10
        )

        # Evaluate student on validation
        acc, f1 = evaluate(student, val_loader)
        print(f"Fold {fold} | Acc={acc:.4f} | Macro-F1={f1:.4f}")
        fold_acc.append(acc); fold_f1.append(f1)

    # Summary
    acc_mean, acc_std = np.mean(fold_acc), np.std(fold_acc)
    f1_mean, f1_std   = np.mean(fold_f1),  np.std(fold_f1)
    print("\n=== Mean Teacher (ResNet-18) — 5-fold Summary ===")
    print(f"Accuracy  : {acc_mean:.4f} ± {acc_std:.4f}")
    print(f"Macro-F1  : {f1_mean:.4f} ± {f1_std:.4f}")
    return fold_acc, fold_f1

# ============================================================
# Run it
# (Adjust epochs for a quicker sanity check; 40–60 is typical for ResNet on spectrograms.)
# ============================================================
fold_acc, fold_f1 = run_cv_mean_teacher_resnet(
    all_image_paths, all_labels,
    k_folds=5, labeled_ratio=0.30, batch_l=16, mu=3, epochs=40
)


Using device: cpu
Scanned 280 images across 7 classes.

Train/Labeled: 67 Train/Unlabeled: 157 Val: 56
[MT-ResNet] epoch 001/40 | sup_acc=0.1642 | lambda=0.02 | lr=0.04992
[MT-ResNet] epoch 002/40 | sup_acc=0.1493 | lambda=0.04 | lr=0.04969
[MT-ResNet] epoch 003/40 | sup_acc=0.1940 | lambda=0.09 | lr=0.04931
[MT-ResNet] epoch 004/40 | sup_acc=0.3134 | lambda=0.17 | lr=0.04878
[MT-ResNet] epoch 005/40 | sup_acc=0.2537 | lambda=0.29 | lr=0.04810
[MT-ResNet] epoch 006/40 | sup_acc=0.1343 | lambda=0.45 | lr=0.04728
[MT-ResNet] epoch 007/40 | sup_acc=0.1343 | lambda=0.64 | lr=0.04632
[MT-ResNet] epoch 008/40 | sup_acc=0.2388 | lambda=0.82 | lr=0.04523
[MT-ResNet] epoch 009/40 | sup_acc=0.3433 | lambda=0.95 | lr=0.04401
[MT-ResNet] epoch 010/40 | sup_acc=0.1642 | lambda=1.00 | lr=0.04268
[MT-ResNet] epoch 011/40 | sup_acc=0.2687 | lambda=1.00 | lr=0.04124
[MT-ResNet] epoch 012/40 | sup_acc=0.2388 | lambda=1.00 | lr=0.03969
[MT-ResNet] epoch 013/40 | sup_acc=0.3582 | lambda=1.00 | lr=0.03806
