<a href="https://colab.research.google.com/github/danimacaya/AIInternational/blob/main/classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the assignment data is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
import warnings

import pandas as pd

# Hide library warnings to keep the notebook output readable.
warnings.filterwarnings("ignore")

# Every relative path used by later cells is resolved against this folder.
current_dir = "/content/drive/MyDrive/Assignment_25_26"
os.chdir(current_dir)

# Training metadata: one row per image with columns US (image file name),
# MASK (mask file name) and LABEL (class id: 0=benign, 1=malignant, 2=normal).
meta_path = os.path.join(current_dir, "training_metadata.xlsx")
df = pd.read_excel(meta_path)

print(df.head())
print(df['LABEL'].value_counts())


Mounted at /content/drive
      US        MASK  LABEL
0  1.png  1_mask.png      2
1  2.png  2_mask.png      2
2  3.png  3_mask.png      2
3  4.png  4_mask.png      2
4  5.png  5_mask.png      2
LABEL
0    679
2    460
1    364
Name: count, dtype: int64


In [None]:
import re

def extract_patient_id(us_name: str) -> str:
    """Return the patient identifier encoded in an ultrasound file name.

    The identifier is the part of the file stem before the first underscore,
    e.g. '1040_3.png' -> '1040'. Names without an underscore ('12.png')
    yield the whole stem ('12').
    """
    stem, _ext = os.path.splitext(us_name)
    patient_id, _sep, _rest = stem.partition('_')
    return patient_id

# Derive a patient id for every image so the train/val split can be done per patient.
df['PATIENT_ID'] = df['US'].astype(str).map(extract_patient_id)

print(df[['US', 'PATIENT_ID']].head(20))
print(df['PATIENT_ID'].nunique(), "unique patients")


        US PATIENT_ID
0    1.png          1
1    2.png          2
2    3.png          3
3    4.png          4
4    5.png          5
5    6.png          6
6    7.png          7
7    8.png          8
8    9.png          9
9   10.png         10
10  11.png         11
11  12.png         12
12  13.png         13
13  14.png         14
14  15.png         15
15  16.png         16
16  17.png         17
17  18.png         18
18  19.png         19
19  20.png         20
968 unique patients


In [None]:
# Human-readable class names for the integer LABEL ids.
label_map = {0: 'benign', 1: 'malignant', 2: 'normal'}
df['LABEL_NAME'] = df['LABEL'].map(label_map)

# An id missing from label_map would silently become NaN and then vanish
# from value_counts(); fail loudly instead of under-reporting a class.
unmapped = df.loc[df['LABEL_NAME'].isna(), 'LABEL'].unique()
assert len(unmapped) == 0, f"Unexpected LABEL values: {unmapped}"

print("Global label counts:")
print(df['LABEL_NAME'].value_counts())


Global label counts:
LABEL_NAME
benign       679
normal       460
malignant    364
Name: count, dtype: int64


In [None]:
# Images per patient and number of distinct labels among them
# (nunique > 1 would mean a patient has images from different classes).
per_patient = df.groupby('PATIENT_ID')['LABEL'].agg(count='count', nunique='nunique')
print(per_patient.head(20))


            count  nunique
PATIENT_ID                
1               1        1
10              1        1
100             1        1
1000            1        1
1001            1        1
1003            1        1
1007            1        1
1008            1        1
1009            1        1
101             1        1
1010            1        1
1012            1        1
1015            1        1
1016            1        1
102             1        1
1020            1        1
1021            1        1
1025            1        1
1026            1        1
103             1        1


In [None]:
# Representative label per patient (mode)
# One label per patient -- the most frequent label among that patient's
# images -- used to stratify the patient-level split below.
rep = (
    df.groupby('PATIENT_ID')['LABEL']
      .agg(lambda labels: labels.value_counts().idxmax())
      .rename('REP_LABEL')
      .reset_index()
)
print(rep['REP_LABEL'].value_counts())


REP_LABEL
0    537
1    286
2    145
Name: count, dtype: int64


In [None]:
from sklearn.model_selection import StratifiedGroupKFold

# Patient-level train/val split. `rep` has one row per patient, so grouping on
# PATIENT_ID keeps all images of a patient on the same side, while REP_LABEL
# stratifies the class mix. Only the first of the 5 folds is used (~80/20).
X = rep['PATIENT_ID'].values
y = rep['REP_LABEL'].values
groups = rep['PATIENT_ID'].values

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
# split() yields positional row indices into `rep`, not patient ids.
train_pids, val_pids = next(cv.split(X, y, groups))

train_patient_ids = rep.iloc[train_pids]['PATIENT_ID']
val_patient_ids   = rep.iloc[val_pids]['PATIENT_ID']

df_train = df[df['PATIENT_ID'].isin(train_patient_ids)].reset_index(drop=True)
df_val   = df[df['PATIENT_ID'].isin(val_patient_ids)].reset_index(drop=True)

# Sanity checks: no patient leaks across the split and no image is dropped.
assert set(train_patient_ids).isdisjoint(val_patient_ids), "Patient overlap between train and val"
assert len(df_train) + len(df_val) == len(df), "Some images were not assigned to a split"

print("Train images:", len(df_train), "Val images:", len(df_val))
print("Train label distribution:\n", df_train['LABEL_NAME'].value_counts())
print("Val label distribution:\n", df_val['LABEL_NAME'].value_counts())


Train images: 1177 Val images: 326
Train label distribution:
 LABEL_NAME
benign       542
normal       355
malignant    280
Name: count, dtype: int64
Val label distribution:
 LABEL_NAME
benign       137
normal       105
malignant     84
Name: count, dtype: int64


In [None]:
# Persist the full metadata (with PATIENT_ID) and the patient-level split so
# the training cells below can reload them from CSV.
for _frame, _csv_name in ((df, "metadata_full_with_patient.csv"),
                          (df_train, "metadata_train.csv"),
                          (df_val, "metadata_val.csv")):
    _frame.to_csv(_csv_name, index=False)



In [None]:
!pip install -q timm albumentations==1.4.3

import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import f1_score, classification_report, confusion_matrix
from tqdm.auto import tqdm

# ----------------- CONFIG -----------------
class Config:
    """Paths and hyper-parameters for the 3-class ultrasound image classifier."""

    # --- data ---
    img_dir = 'training_images'            # folder with the ultrasound .png files
    train_meta = 'metadata_train.csv'      # patient-level split from the preprocessing cells
    val_meta = 'metadata_val.csv'

    # --- model / outputs ---
    output_dir = 'outputs_cls'             # best checkpoint + training history CSV
    model_name = 'tf_efficientnetv2_s'     # timm backbone identifier
    num_classes = 3                        # 0=benign, 1=malignant, 2=normal
    img_size = 384                         # square input resolution

    # --- optimisation ---
    epochs = 40
    batch_size = 16
    lr = 1e-4
    weight_decay = 1e-5

    # --- runtime ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_workers = 4
    seed = 42


config = Config()
os.makedirs(name=config.output_dir, exist_ok=True)

# ----------------- SEED -----------------
def set_seed(seed):
    """Seed every random number generator the training pipeline draws from.

    Python's ``random`` module is seeded too: albumentations 1.x draws many of
    its augmentation decisions from it, so seeding only numpy/torch leaves the
    augmentation stream different on every run.

    Args:
        seed: integer seed applied to random, numpy and torch (CPU and CUDA).
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(seed=config.seed)  # seed before any model/augmentation randomness is drawn

# ----------------- TRANSFORMS -----------------
def get_transforms(phase='train'):
    """Build the albumentations pipeline for the classifier.

    Every phase resizes to ``config.img_size`` and finishes with normalisation
    using the ImageNet mean/std followed by tensor conversion. Only
    ``phase == 'train'`` inserts the random geometric and noise/blur
    augmentations; any other value yields the deterministic eval pipeline.
    """
    resize = [A.Resize(config.img_size, config.img_size)]
    to_model_input = [
        A.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    if phase != 'train':
        return A.Compose(resize + to_model_input)

    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05,
                           scale_limit=0.1,
                           rotate_limit=15,
                           border_mode=0,   # constant border (zeros)
                           p=0.4),
        A.OneOf([
            A.GaussNoise(var_limit=(5.0, 25.0), p=0.5),
            A.GaussianBlur(blur_limit=(3, 5), p=0.5),
        ], p=0.3),
    ]
    return A.Compose(resize + augmentations + to_model_input)

# ----------------- DATASET -----------------
class BreastUSDataset(Dataset):
    """Image-level classification dataset backed by a metadata frame.

    Each row provides ``US`` (image file name inside ``img_dir``) and
    ``LABEL`` (integer class id). Images are read as single-channel grayscale
    and replicated to 3 channels so RGB backbones can consume them.

    Args:
        df: metadata frame with at least the ``US`` and ``LABEL`` columns.
        img_dir: directory containing the image files.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)`` and expected to return a dict with an
            ``'image'`` entry. Without it, the image is returned as a float
            CHW tensor scaled to [0, 1] (instead of a raw HWC uint8 array
            the model could not consume).
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['US'])   # e.g. '1.png'
        label = int(row['LABEL'])                          # 0/1/2

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as img:
            gray = np.array(img.convert('L'))              # H,W
        image = np.stack([gray, gray, gray], axis=2)       # H,W -> H,W,3

        if self.transform:
            image = self.transform(image=image)['image']
        else:
            # Fallback: CHW float tensor in [0, 1] so DataLoader batches are model-ready.
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        return image, label

# ----------------- MODEL -----------------
class USClassifier(nn.Module):
    """timm backbone followed by a small MLP head that outputs class logits.

    The backbone is created with ``num_classes=0`` and its ``num_features``
    sized output is fed through Dropout(0.4) -> Linear(->512) -> ReLU ->
    Dropout(0.3) -> Linear(512 -> num_classes).

    Args:
        model_name: timm model identifier (default ``config.model_name``).
        num_classes: number of output logits (default ``config.num_classes``).
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_name=config.model_name,
                 num_classes=config.num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name,
                                          pretrained=pretrained,
                                          num_classes=0)
        hidden = 512
        self.classifier = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(self.backbone.num_features, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.backbone(x))

# ----------------- TRAIN / VALIDATE -----------------
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of (images, labels) batches.
        criterion: loss taking (logits, labels).
        optimizer: optimizer stepping `model`'s parameters.
        device: device the batches are moved to.

    Returns:
        (mean loss per batch, macro-F1 of argmax predictions over the epoch)
    """
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []

    pbar = tqdm(loader, desc='Train', leave=False)
    for step, (images, labels) in enumerate(pbar, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = outputs.argmax(1)
        all_preds.extend(preds.detach().cpu().numpy())
        all_labels.extend(labels.detach().cpu().numpy())

        # Average over the batches seen so far; dividing by len(pbar) showed a
        # loss that started near zero and only reached the true mean at the end.
        pbar.set_postfix(loss=running_loss / step)

    epoch_loss = running_loss / len(loader)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')
    return epoch_loss, epoch_f1

def validate_epoch(model, loader, criterion, device):
    """Evaluate `model` on `loader` with gradients disabled.

    Returns:
        (mean loss per batch, macro-F1, predictions array, labels array),
        the arrays being in loader order.
    """
    model.eval()
    total_loss = 0.0
    pred_list, label_list = [], []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Val', leave=False):
            images, labels = images.to(device), labels.to(device)

            logits = model(images)
            total_loss += criterion(logits, labels).item()

            pred_list.extend(logits.argmax(1).detach().cpu().numpy())
            label_list.extend(labels.detach().cpu().numpy())

    mean_loss = total_loss / len(loader)
    macro_f1 = f1_score(label_list, pred_list, average='macro')
    return mean_loss, macro_f1, np.array(pred_list), np.array(label_list)

# ----------------- MAIN TRAINING -----------------
def _class_weights(labels, num_classes):
    """Inverse-frequency cross-entropy weights indexed by class id.

    Args:
        labels: pandas Series of integer class ids.
        num_classes: number of classes; the result has exactly this length.

    Returns:
        float64 numpy array whose entry i belongs to class id i, normalised
        to sum to ``num_classes``.
    """
    # reindex (rather than sort_index) keeps entry i aligned with class i even
    # when some class has no samples in `labels`.
    counts = (labels.value_counts()
                    .reindex(range(num_classes), fill_value=0)
                    .to_numpy(dtype=np.float64))
    weights = 1.0 / (counts + 1e-6)
    return weights / weights.sum() * num_classes


def _make_loader(dataset, shuffle, num_workers):
    """DataLoader with the configured batch size; pins memory only on CUDA."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=config.device == 'cuda',
    )


def main():
    """Train the classifier, keep the best-val-F1 checkpoint, and report on it.

    Writes ``best_cls_model.pth`` and ``cls_training_history.csv`` into
    ``config.output_dir``.
    """
    print(f"Device: {config.device}")
    print(f"Backbone: {config.model_name}")

    train_df = pd.read_csv(config.train_meta)
    val_df   = pd.read_csv(config.val_meta)

    print("Train size:", len(train_df), "Val size:", len(val_df))
    print("Train label counts:\n", train_df['LABEL'].value_counts())
    print("Val label counts:\n", val_df['LABEL'].value_counts())

    train_ds = BreastUSDataset(train_df, config.img_dir, transform=get_transforms('train'))
    val_ds   = BreastUSDataset(val_df,   config.img_dir, transform=get_transforms('val'))

    # Single-process data loading (Config.num_workers is not used here). Kept
    # local so the shared `config` object is not mutated as a side effect.
    num_workers = 0
    train_loader = _make_loader(train_ds, shuffle=True, num_workers=num_workers)
    val_loader   = _make_loader(val_ds, shuffle=False, num_workers=num_workers)

    model = USClassifier().to(config.device)

    # Class weights to mitigate label imbalance.
    class_weights = torch.tensor(
        _class_weights(train_df['LABEL'], config.num_classes),
        dtype=torch.float32,
    ).to(config.device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.AdamW(model.parameters(),
                            lr=config.lr,
                            weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.epochs, eta_min=1e-6
    )

    # -inf guarantees the first epoch writes a checkpoint, so the reload below
    # can never hit a missing file (e.g. if val F1 stays at 0).
    best_f1 = float('-inf')
    best_path = os.path.join(config.output_dir, 'best_cls_model.pth')
    history = {'train_loss': [], 'val_loss': [], 'train_f1': [], 'val_f1': []}

    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch+1}/{config.epochs}")
        train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, config.device)
        val_loss, val_f1, _, _ = validate_epoch(model, val_loader, criterion, config.device)

        scheduler.step()

        # Plain Python floats keep numpy scalar objects out of the checkpoint,
        # which torch.load's weights_only mode may refuse to unpickle.
        history['train_loss'].append(float(train_loss))
        history['val_loss'].append(float(val_loss))
        history['train_f1'].append(float(train_f1))
        history['val_f1'].append(float(val_f1))

        print(f"Train Loss: {train_loss:.4f} | F1: {train_f1:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | F1: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = float(val_f1)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'history': history,
                'best_f1': best_f1,
            }, best_path)
            print(f"✓ Saved best model (F1={best_f1:.4f})")

    # Reload the best weights for the final report. The in-memory `history`
    # is kept: the copy stored in the checkpoint stops at the best epoch.
    ckpt = torch.load(best_path, map_location=config.device)
    model.load_state_dict(ckpt['model_state_dict'])
    print("\nBest Val F1:", ckpt['best_f1'])

    class_ids = list(range(config.num_classes))
    _, _, val_preds, val_labels = validate_epoch(model, val_loader, criterion, config.device)
    print("\nClassification report (val):")
    target_names = ['benign (0)', 'malignant (1)', 'normal (2)']
    # Explicit `labels` keeps the report valid even if a class is absent
    # from both the predictions and the targets.
    print(classification_report(val_labels, val_preds,
                                labels=class_ids, target_names=target_names))

    cm = confusion_matrix(val_labels, val_preds, labels=class_ids)
    print("Confusion matrix (val):\n", cm)

    # Per-epoch curves for the report (all epochs, not just up to the best one).
    hist_df = pd.DataFrame(history)
    hist_df.to_csv(os.path.join(config.output_dir, 'cls_training_history.csv'), index=False)

if __name__ == "__main__":
    main()


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/137.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.0/137.0 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDevice: cuda
Backbone: tf_efficientnetv2_s
Train size: 1177 Val size: 326
Train label counts:
 LABEL
0    542
2    355
1    280
Name: count, dtype: int64
Val label counts:
 LABEL
0    137
2    105
1     84
Name: count, dtype: int64


model.safetensors:   0%|          | 0.00/86.5M [00:00<?, ?B/s]


Epoch 1/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.9166 | F1: 0.5641
Val   Loss: 0.7564 | F1: 0.7022
✓ Saved best model (F1=0.7022)

Epoch 2/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.5883 | F1: 0.7466
Val   Loss: 0.5833 | F1: 0.7636
✓ Saved best model (F1=0.7636)

Epoch 3/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.4718 | F1: 0.8046
Val   Loss: 0.5099 | F1: 0.8026
✓ Saved best model (F1=0.8026)

Epoch 4/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.3794 | F1: 0.8404
Val   Loss: 0.4626 | F1: 0.8202
✓ Saved best model (F1=0.8202)

Epoch 5/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.3648 | F1: 0.8648
Val   Loss: 0.6295 | F1: 0.7493

Epoch 6/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.2769 | F1: 0.8795
Val   Loss: 0.7100 | F1: 0.7725

Epoch 7/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.2270 | F1: 0.9118
Val   Loss: 0.5801 | F1: 0.8188

Epoch 8/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.2453 | F1: 0.9005
Val   Loss: 0.7118 | F1: 0.7897

Epoch 9/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1799 | F1: 0.9314
Val   Loss: 0.6709 | F1: 0.8012

Epoch 10/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1745 | F1: 0.9292
Val   Loss: 0.6758 | F1: 0.7811

Epoch 11/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1870 | F1: 0.9264
Val   Loss: 0.7525 | F1: 0.7859

Epoch 12/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1425 | F1: 0.9458
Val   Loss: 0.8655 | F1: 0.7647

Epoch 13/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1339 | F1: 0.9523
Val   Loss: 0.8385 | F1: 0.8011

Epoch 14/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1191 | F1: 0.9522
Val   Loss: 0.8428 | F1: 0.7845

Epoch 15/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.1052 | F1: 0.9650
Val   Loss: 0.8113 | F1: 0.7955

Epoch 16/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0768 | F1: 0.9686
Val   Loss: 0.7606 | F1: 0.8068

Epoch 17/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0712 | F1: 0.9736
Val   Loss: 0.9024 | F1: 0.7866

Epoch 18/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0821 | F1: 0.9695
Val   Loss: 0.8379 | F1: 0.8026

Epoch 19/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0871 | F1: 0.9633
Val   Loss: 0.8227 | F1: 0.7950

Epoch 20/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0652 | F1: 0.9751
Val   Loss: 0.7997 | F1: 0.7944

Epoch 21/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0511 | F1: 0.9782
Val   Loss: 0.9532 | F1: 0.7970

Epoch 22/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0714 | F1: 0.9715
Val   Loss: 0.7222 | F1: 0.8198

Epoch 23/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0667 | F1: 0.9712
Val   Loss: 0.9107 | F1: 0.8059

Epoch 24/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0446 | F1: 0.9836
Val   Loss: 0.8054 | F1: 0.8037

Epoch 25/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0367 | F1: 0.9857
Val   Loss: 0.8181 | F1: 0.8042

Epoch 26/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0503 | F1: 0.9847
Val   Loss: 0.8454 | F1: 0.7952

Epoch 27/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0382 | F1: 0.9870
Val   Loss: 1.0070 | F1: 0.7815

Epoch 28/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0460 | F1: 0.9813
Val   Loss: 0.7838 | F1: 0.8019

Epoch 29/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0346 | F1: 0.9853
Val   Loss: 0.8470 | F1: 0.7975

Epoch 30/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0361 | F1: 0.9895
Val   Loss: 0.8019 | F1: 0.7946

Epoch 31/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0203 | F1: 0.9912
Val   Loss: 0.7685 | F1: 0.7911

Epoch 32/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0229 | F1: 0.9889
Val   Loss: 0.8408 | F1: 0.7939

Epoch 33/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0266 | F1: 0.9899
Val   Loss: 0.8453 | F1: 0.8003

Epoch 34/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0325 | F1: 0.9889
Val   Loss: 0.8879 | F1: 0.7999

Epoch 35/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0293 | F1: 0.9924
Val   Loss: 0.8454 | F1: 0.8009

Epoch 36/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0239 | F1: 0.9941
Val   Loss: 0.8337 | F1: 0.8114

Epoch 37/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0180 | F1: 0.9936
Val   Loss: 0.8318 | F1: 0.8057

Epoch 38/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0214 | F1: 0.9932
Val   Loss: 0.8636 | F1: 0.8035

Epoch 39/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0258 | F1: 0.9921
Val   Loss: 0.8512 | F1: 0.7953

Epoch 40/40


Train:   0%|          | 0/74 [00:00<?, ?it/s]

Val:   0%|          | 0/21 [00:00<?, ?it/s]

Train Loss: 0.0281 | F1: 0.9871
Val   Loss: 0.8971 | F1: 0.7963

Best Val F1: 0.8201512857358703


Val:   0%|          | 0/21 [00:00<?, ?it/s]


Classification report (val):
               precision    recall  f1-score   support

   benign (0)       0.82      0.86      0.84       137
malignant (1)       0.75      0.71      0.73        84
   normal (2)       0.90      0.88      0.89       105

     accuracy                           0.83       326
    macro avg       0.82      0.82      0.82       326
 weighted avg       0.83      0.83      0.83       326

Confusion matrix (val):
 [[118  12   7]
 [ 21  60   3]
 [  5   8  92]]


In [None]:
!pip install -q albumentations==1.4.3

import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import balanced_accuracy_score

from sklearn.metrics import f1_score
from tqdm.auto import tqdm

# ----------------- CONFIG -----------------
class SegConfig:
    """Paths and hyper-parameters for the benign+malignant lesion U-Net."""

    # --- data ---
    img_dir = 'training_images'          # images and their *_mask.png files
    train_meta = 'metadata_train.csv'    # same patient-level split as the classifier
    val_meta = 'metadata_val.csv'

    # --- outputs / input size ---
    output_dir = 'outputs_seg_bm'
    img_size = 384

    # --- optimisation ---
    epochs = 60
    batch_size = 8
    lr = 1e-4
    weight_decay = 1e-5

    # --- runtime ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_workers = 0                      # single-process loading, safest in Colab
    seed = 42


seg_cfg = SegConfig()
os.makedirs(name=seg_cfg.output_dir, exist_ok=True)

# ----------------- SEED -----------------
def set_seed(seed):
    """Seed every random number generator the training pipeline draws from.

    Python's ``random`` module is seeded too: albumentations 1.x draws many of
    its augmentation decisions from it, so seeding only numpy/torch leaves the
    augmentation stream different on every run.

    Args:
        seed: integer seed applied to random, numpy and torch (CPU and CUDA).
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(seed=seg_cfg.seed)  # seed before any model/augmentation randomness is drawn

# ----------------- TRANSFORMS -----------------
def get_seg_transforms(phase='train'):
    """Build the albumentations pipeline for (image, mask) pairs.

    Every phase resizes to ``seg_cfg.img_size`` and ends with
    ``Normalize(mean=0, std=1)`` -- with albumentations' default
    ``max_pixel_value=255`` this just rescales to [0, 1] -- plus tensor
    conversion. Only ``phase == 'train'`` adds random augmentation.
    """
    resize = [A.Resize(seg_cfg.img_size, seg_cfg.img_size)]
    finalize = [
        A.Normalize(mean=[0.0], std=[1.0]),
        ToTensorV2(),
    ]

    if phase != 'train':
        return A.Compose(resize + finalize)

    augment = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05,
                           scale_limit=0.1,
                           rotate_limit=15,
                           border_mode=0,   # constant border (zeros)
                           p=0.4),
        A.OneOf([
            A.GaussNoise(var_limit=(5.0, 25.0), p=0.5),
            A.GaussianBlur(blur_limit=(3, 5), p=0.5),
        ], p=0.3),
    ]
    return A.Compose(resize + augment + finalize)

# ----------------- DATASET -----------------
class BMSegmentationDataset(Dataset):
    """
    2-class lesion segmentation (B+M vs background).
    Uses LABEL ∈ {0,1} and MASK as foreground.

    Args:
        df: metadata frame with ``US``, ``MASK`` and ``LABEL`` columns; rows
            with LABEL 2 (normal) are dropped. The input frame is not modified.
        img_dir: directory containing both images and masks.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)``. Without it, the image is
            returned as a 1xHxW float tensor scaled to [0, 1], matching the
            scale produced by the Normalize(mean=0, std=1) pipelines.

    Each item is ``(image, mask)`` where mask is a 1xHxW float tensor of 0/1.
    """

    def __init__(self, df, img_dir, transform=None):
        # keep only benign + malignant
        self.df = df[df['LABEL'].isin([0, 1])].reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path  = os.path.join(self.img_dir, row['US'])
        mask_path = os.path.join(self.img_dir, row['MASK'])

        # grayscale image, H,W uint8; context managers release the file handles
        with Image.open(img_path) as img:
            image = np.array(img.convert('L'))
        # mask: any non-zero pixel counts as foreground -> H,W float32 in {0, 1}
        with Image.open(mask_path) as msk:
            mask = (np.array(msk.convert('L')) > 0).astype('float32')

        if self.transform:
            aug = self.transform(image=image, mask=mask)
            image = aug['image']        # 1xHxW when the pipeline ends with ToTensorV2
            mask  = aug['mask']         # HxW
        else:
            # Same [0, 1] scale as the transform path, so the model sees
            # consistent inputs regardless of how the dataset is built.
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0
            mask  = torch.from_numpy(mask).float()

        # make mask shape 1xHxW
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)

        return image, mask

# ----------------- U-NET MODEL -----------------
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 preserves H and W."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # One Sequential named `net` keeps parameter names (net.0.weight, ...) unchanged.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    """Four-level U-Net returning per-pixel logits.

    Args:
        n_channels: number of input channels (1 for the grayscale images here).
        n_classes: number of output logit maps (1 for binary segmentation).

    Input height/width need not be divisible by 16: whenever max-pool flooring
    makes an upsampled map smaller than its skip connection, the upsampled map
    is zero-padded to the skip's size before concatenation, so the output has
    the same height and width as the input.
    """

    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        self.inc   = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 512))

        self.up1   = nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(512 + 512, 256)
        self.up2   = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256 + 256, 128)
        self.up3   = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128 + 128, 64)
        self.up4   = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(64 + 64, 64)

        self.outc  = nn.Conv2d(64, n_classes, kernel_size=1)

    @staticmethod
    def _match_and_concat(x, skip):
        """Zero-pad `x` to `skip`'s spatial size, then concatenate on channels."""
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            # pad order for the last two dims: (left, right, top, bottom)
            x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                                      diff_h // 2, diff_h - diff_h // 2])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.conv1(self._match_and_concat(self.up1(x5), x4))
        x = self.conv2(self._match_and_concat(self.up2(x), x3))
        x = self.conv3(self._match_and_concat(self.up3(x), x2))
        x = self.conv4(self._match_and_concat(self.up4(x), x1))
        logits = self.outc(x)  # B x n_classes x H x W
        return logits

# ----------------- LOSSES & METRICS -----------------
class DiceBCELoss(nn.Module):
    """Binary segmentation loss: BCE-with-logits plus (1 - soft Dice).

    The Dice term is computed once over the whole batch (all pixels of all
    samples flattened together), with ``smooth`` added to both numerator
    and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, logits, targets):
        p = torch.sigmoid(logits).view(-1)
        t = targets.view(-1)

        overlap = torch.sum(p * t)
        denom = torch.sum(p) + torch.sum(t) + self.smooth
        soft_dice = (2. * overlap + self.smooth) / denom

        return self.bce(logits, targets) + (1 - soft_dice)

def dice_coeff(preds, targets, threshold=0.5, eps=1e-7):
    """Mean per-sample hard Dice score of thresholded sigmoid predictions.

    Args:
        preds: logits, indexed as (B, C, H, W) -- Dice is summed over dims 1..3.
        targets: binary masks with the same shape as `preds`.
        threshold: probability above which a pixel counts as foreground.
        eps: smoothing term; a sample with empty prediction and empty target
            scores 1.

    Returns:
        Python float, the batch mean of the per-sample Dice scores.
    """
    hard = torch.sigmoid(preds).gt(threshold).float()
    dims = (1, 2, 3)
    overlap = torch.sum(hard * targets, dim=dims)
    total = torch.sum(hard, dim=dims) + torch.sum(targets, dim=dims)
    return ((2 * overlap + eps) / (total + eps)).mean().item()

# ----------------- TRAIN / VALID -----------------
def train_seg_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass of the segmentation model over `loader`.

    Returns:
        (mean loss per batch, mean per-batch Dice as computed by dice_coeff)
    """
    model.train()
    loss_sum, dice_sum, seen = 0.0, 0.0, 0

    progress = tqdm(loader, desc='Seg Train', leave=False)
    for images, masks in progress:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        seen += 1
        loss_sum += loss.item()
        dice_sum += dice_coeff(logits.detach(), masks.detach())

        progress.set_postfix(loss=loss_sum/seen,
                             dice=dice_sum/seen)

    return loss_sum/seen, dice_sum/seen

def valid_seg_epoch(model, loader, criterion, device):
    """Evaluate the segmentation model on `loader` with gradients disabled.

    Returns:
        (mean loss per batch, mean per-batch Dice as computed by dice_coeff)
    """
    model.eval()
    loss_sum, dice_sum, seen = 0.0, 0.0, 0

    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Seg Val', leave=False):
            images, masks = images.to(device), masks.to(device)

            logits = model(images)
            loss_sum += criterion(logits, masks).item()
            dice_sum += dice_coeff(logits, masks)
            seen += 1

    return loss_sum/seen, dice_sum/seen

# ----------------- MAIN -----------------
def main():
    """Train the binary lesion-vs-background U-Net on benign/malignant images.

    Reads the train/val metadata CSVs named in ``seg_cfg``, trains for
    ``seg_cfg.epochs`` epochs with ``DiceBCELoss`` + AdamW + cosine annealing,
    checkpoints the epoch with the best validation Dice to ``best_seg_bm.pth``
    and writes the per-epoch curves to ``seg_bm_training_history.csv`` (both in
    ``seg_cfg.output_dir``).
    """
    dev = seg_cfg.device
    print("Device:", dev)

    train_df = pd.read_csv(seg_cfg.train_meta)
    val_df   = pd.read_csv(seg_cfg.val_meta)

    print("Original train size:", len(train_df), "val size:", len(val_df))
    print("Train label counts:\n", train_df['LABEL'].value_counts())
    print("Val label counts:\n",   val_df['LABEL'].value_counts())

    # BMSegmentationDataset keeps only the benign/malignant rows (LABEL 0/1).
    train_ds = BMSegmentationDataset(train_df, seg_cfg.img_dir,
                                     transform=get_seg_transforms('train'))
    val_ds = BMSegmentationDataset(val_df, seg_cfg.img_dir,
                                   transform=get_seg_transforms('val'))

    print("Seg train size (B+M):", len(train_ds), "Seg val size (B+M):", len(val_ds))

    loader_kwargs = dict(batch_size=seg_cfg.batch_size,
                         num_workers=seg_cfg.num_workers,
                         pin_memory=False)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = UNet(n_channels=1, n_classes=1).to(dev)
    criterion = DiceBCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=seg_cfg.lr,
                            weight_decay=seg_cfg.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=seg_cfg.epochs,
                                                     eta_min=1e-6)

    best_dice = 0.0
    best_path = os.path.join(seg_cfg.output_dir, 'best_seg_bm.pth')
    history_keys = ('train_loss', 'val_loss', 'train_dice', 'val_dice')
    history = {key: [] for key in history_keys}

    for epoch in range(seg_cfg.epochs):
        print(f"\nEpoch {epoch+1}/{seg_cfg.epochs}")
        tr_loss, tr_dice = train_seg_epoch(model, train_loader, criterion, optimizer, dev)
        val_loss, val_dice = valid_seg_epoch(model, val_loader, criterion, dev)
        scheduler.step()

        for key, value in zip(history_keys, (tr_loss, val_loss, tr_dice, val_dice)):
            history[key].append(value)

        print(f"Train Loss: {tr_loss:.4f} | Dice: {tr_dice:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Dice: {val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'history': history,
                'best_dice': best_dice,
            }
            torch.save(checkpoint, best_path)
            print(f"✓ Saved best model (Dice={best_dice:.4f})")

    print("\nBest Val Dice:", best_dice)

    # Per-epoch curves for plotting.
    history_csv = os.path.join(seg_cfg.output_dir, 'seg_bm_training_history.csv')
    pd.DataFrame(history).to_csv(history_csv, index=False)

if __name__ == "__main__":
    main()


Device: cuda
Original train size: 1177 val size: 326
Train label counts:
 LABEL
0    542
2    355
1    280
Name: count, dtype: int64
Val label counts:
 LABEL
0    137
2    105
1     84
Name: count, dtype: int64
Seg train size (B+M): 822 Seg val size (B+M): 221

Epoch 1/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.4457 | Dice: 0.0804
Val   Loss: 1.3593 | Dice: 0.0752
✓ Saved best model (Dice=0.0752)

Epoch 2/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.3165 | Dice: 0.0445
Val   Loss: 1.2987 | Dice: 0.0958
✓ Saved best model (Dice=0.0958)

Epoch 3/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.2476 | Dice: 0.2196
Val   Loss: 1.3162 | Dice: 0.2693
✓ Saved best model (Dice=0.2693)

Epoch 4/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.1523 | Dice: 0.3666
Val   Loss: 1.1339 | Dice: 0.4317
✓ Saved best model (Dice=0.4317)

Epoch 5/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.0741 | Dice: 0.4192
Val   Loss: 1.0667 | Dice: 0.4022

Epoch 6/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 1.0162 | Dice: 0.4512
Val   Loss: 1.0264 | Dice: 0.4398
✓ Saved best model (Dice=0.4398)

Epoch 7/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.9516 | Dice: 0.4806
Val   Loss: 0.9117 | Dice: 0.5092
✓ Saved best model (Dice=0.5092)

Epoch 8/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.9152 | Dice: 0.4899
Val   Loss: 0.8992 | Dice: 0.4868

Epoch 9/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.8632 | Dice: 0.5168
Val   Loss: 0.9728 | Dice: 0.4558

Epoch 10/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.8370 | Dice: 0.5192
Val   Loss: 0.7966 | Dice: 0.5789
✓ Saved best model (Dice=0.5789)

Epoch 11/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.7876 | Dice: 0.5442
Val   Loss: 0.7681 | Dice: 0.5598

Epoch 12/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.7614 | Dice: 0.5529
Val   Loss: 0.7535 | Dice: 0.5900
✓ Saved best model (Dice=0.5900)

Epoch 13/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.7507 | Dice: 0.5463
Val   Loss: 0.7470 | Dice: 0.5514

Epoch 14/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.7122 | Dice: 0.5639
Val   Loss: 0.6765 | Dice: 0.5825

Epoch 15/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6998 | Dice: 0.5700
Val   Loss: 0.7008 | Dice: 0.5872

Epoch 16/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6815 | Dice: 0.5754
Val   Loss: 0.6086 | Dice: 0.6179
✓ Saved best model (Dice=0.6179)

Epoch 17/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6596 | Dice: 0.5830
Val   Loss: 0.6162 | Dice: 0.6343
✓ Saved best model (Dice=0.6343)

Epoch 18/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6354 | Dice: 0.5999
Val   Loss: 0.6056 | Dice: 0.6188

Epoch 19/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6259 | Dice: 0.6037
Val   Loss: 0.6266 | Dice: 0.6030

Epoch 20/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6064 | Dice: 0.6061
Val   Loss: 0.5920 | Dice: 0.6222

Epoch 21/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5966 | Dice: 0.6165
Val   Loss: 0.5889 | Dice: 0.6180

Epoch 22/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.6007 | Dice: 0.6142
Val   Loss: 0.5977 | Dice: 0.6177

Epoch 23/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5881 | Dice: 0.6215
Val   Loss: 0.6241 | Dice: 0.5939

Epoch 24/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5596 | Dice: 0.6311
Val   Loss: 0.5670 | Dice: 0.6286

Epoch 25/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5593 | Dice: 0.6269
Val   Loss: 0.5644 | Dice: 0.6339

Epoch 26/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5673 | Dice: 0.6302
Val   Loss: 0.5457 | Dice: 0.6360
✓ Saved best model (Dice=0.6360)

Epoch 27/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5421 | Dice: 0.6380
Val   Loss: 0.5214 | Dice: 0.6529
✓ Saved best model (Dice=0.6529)

Epoch 28/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5297 | Dice: 0.6455
Val   Loss: 0.5251 | Dice: 0.6447

Epoch 29/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5238 | Dice: 0.6569
Val   Loss: 0.5165 | Dice: 0.6480

Epoch 30/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5215 | Dice: 0.6530
Val   Loss: 0.5294 | Dice: 0.6475

Epoch 31/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5201 | Dice: 0.6515
Val   Loss: 0.4925 | Dice: 0.6781
✓ Saved best model (Dice=0.6781)

Epoch 32/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5186 | Dice: 0.6512
Val   Loss: 0.5006 | Dice: 0.6705

Epoch 33/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.5049 | Dice: 0.6628
Val   Loss: 0.5051 | Dice: 0.6462

Epoch 34/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4979 | Dice: 0.6655
Val   Loss: 0.4869 | Dice: 0.6789
✓ Saved best model (Dice=0.6789)

Epoch 35/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4941 | Dice: 0.6662
Val   Loss: 0.4834 | Dice: 0.6796
✓ Saved best model (Dice=0.6796)

Epoch 36/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4941 | Dice: 0.6658
Val   Loss: 0.5040 | Dice: 0.6566

Epoch 37/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4845 | Dice: 0.6706
Val   Loss: 0.4709 | Dice: 0.6748

Epoch 38/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4632 | Dice: 0.6861
Val   Loss: 0.4759 | Dice: 0.6806
✓ Saved best model (Dice=0.6806)

Epoch 39/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4704 | Dice: 0.6783
Val   Loss: 0.4834 | Dice: 0.6708

Epoch 40/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4737 | Dice: 0.6730
Val   Loss: 0.4988 | Dice: 0.6625

Epoch 41/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4831 | Dice: 0.6813
Val   Loss: 0.4731 | Dice: 0.6656

Epoch 42/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4939 | Dice: 0.6785
Val   Loss: 0.4676 | Dice: 0.6823
✓ Saved best model (Dice=0.6823)

Epoch 43/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4658 | Dice: 0.6868
Val   Loss: 0.4719 | Dice: 0.6679

Epoch 44/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4504 | Dice: 0.6926
Val   Loss: 0.4732 | Dice: 0.6593

Epoch 45/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4556 | Dice: 0.6946
Val   Loss: 0.4985 | Dice: 0.6536

Epoch 46/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4503 | Dice: 0.6982
Val   Loss: 0.4669 | Dice: 0.6747

Epoch 47/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4382 | Dice: 0.7049
Val   Loss: 0.4598 | Dice: 0.6844
✓ Saved best model (Dice=0.6844)

Epoch 48/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4421 | Dice: 0.7020
Val   Loss: 0.4571 | Dice: 0.6850
✓ Saved best model (Dice=0.6850)

Epoch 49/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4360 | Dice: 0.7005
Val   Loss: 0.4654 | Dice: 0.6772

Epoch 50/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4404 | Dice: 0.7018
Val   Loss: 0.4608 | Dice: 0.6846

Epoch 51/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4245 | Dice: 0.7113
Val   Loss: 0.4667 | Dice: 0.6713

Epoch 52/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4308 | Dice: 0.7164
Val   Loss: 0.4592 | Dice: 0.6827

Epoch 53/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4236 | Dice: 0.7144
Val   Loss: 0.4577 | Dice: 0.6832

Epoch 54/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4282 | Dice: 0.7155
Val   Loss: 0.4524 | Dice: 0.6856
✓ Saved best model (Dice=0.6856)

Epoch 55/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4339 | Dice: 0.7046
Val   Loss: 0.4529 | Dice: 0.6858
✓ Saved best model (Dice=0.6858)

Epoch 56/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4276 | Dice: 0.7149
Val   Loss: 0.4595 | Dice: 0.6798

Epoch 57/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4318 | Dice: 0.7105
Val   Loss: 0.4564 | Dice: 0.6800

Epoch 58/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4164 | Dice: 0.7252
Val   Loss: 0.4649 | Dice: 0.6754

Epoch 59/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4177 | Dice: 0.7248
Val   Loss: 0.4559 | Dice: 0.6823

Epoch 60/60


Seg Train:   0%|          | 0/103 [00:00<?, ?it/s]

Seg Val:   0%|          | 0/28 [00:00<?, ?it/s]

Train Loss: 0.4315 | Dice: 0.7143
Val   Loss: 0.4511 | Dice: 0.6851

Best Val Dice: 0.6857801271336419


In [None]:
!pip install -q albumentations==1.4.3 segmentation-models-pytorch

import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score, classification_report, confusion_matrix, balanced_accuracy_score
from tqdm.auto import tqdm
import segmentation_models_pytorch as smp
from sklearn.metrics import balanced_accuracy_score


# ----------------- CONFIG -----------------
class Pipeline2Config:
    """Settings for Pipeline 2 (3-class segmentation -> B vs M classifier)."""

    # --- data locations (relative to the working directory) ---
    img_dir = 'training_images'
    train_meta = 'metadata_train.csv'
    val_meta = 'metadata_val.csv'
    output_dir = 'outputs_pipeline2'

    # --- training schedule ---
    img_size = 384
    seg_epochs = 60
    cls_epochs = 40
    batch_size = 8
    lr = 1e-4
    weight_decay = 1e-5

    # --- runtime ---
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    num_workers = 0
    seed = 42

cfg = Pipeline2Config()
os.makedirs(cfg.output_dir, exist_ok=True)

# ----------------- SEED -----------------
def set_seed(seed):
    """Seed every RNG this notebook relies on and make cuDNN deterministic.

    Python's ``random`` module is seeded as well: albumentations 1.x draws its
    augmentation decisions from it, so seeding only numpy/torch leaves the
    training augmentations non-reproducible.

    Args:
        seed: integer seed applied to ``random``, numpy and torch (CPU + all GPUs).
    """
    import random  # not imported by the notebook's import cell

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed once, right after the config is built, before any data/model objects exist.
set_seed(seed=cfg.seed)

# ----------------- 3-CLASS SEGMENTATION DATASET -----------------
class NBM3ClassSegDataset(Dataset):
    """3-class lesion segmentation dataset.

    Per-pixel target classes:
        0 = background (all non-lesion pixels, and every pixel of a normal image)
        1 = benign lesion
        2 = malignant lesion

    The metadata ``LABEL`` column uses a different convention
    (0=benign, 1=malignant, 2=normal), so it is remapped through
    ``LABEL_TO_SEG_CLASS`` before being painted onto the lesion pixels.
    Painting the raw LABEL would make benign lesions indistinguishable from
    background and shift malignant lesions onto the "benign" channel.

    Each item is ``(image, mask)``: ``image`` is the single-channel grayscale
    image (as produced by ``transform``, or a (1, H, W) float tensor when no
    transform is given) and ``mask`` is a long tensor of pixel classes.
    """

    # metadata LABEL -> pixel class assigned to that image's lesion pixels
    LABEL_TO_SEG_CLASS = {0: 1, 1: 2, 2: 0}

    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @classmethod
    def _to_3class_mask(cls, mask_gt, label):
        """Turn a lesion mask (any value > 0 is lesion) into a 3-class uint8 target map."""
        label = int(label)
        if label not in cls.LABEL_TO_SEG_CLASS:
            raise ValueError(
                f"Unexpected LABEL {label}; expected one of {sorted(cls.LABEL_TO_SEG_CLASS)}"
            )
        mask_gt = np.asarray(mask_gt)
        mask_3class = np.zeros(mask_gt.shape, dtype=np.uint8)
        mask_3class[mask_gt > 0] = cls.LABEL_TO_SEG_CLASS[label]
        return mask_3class

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path  = os.path.join(self.img_dir, row['US'])
        mask_path = os.path.join(self.img_dir, row['MASK'])

        # grayscale image, (H, W) float32 in 0-255
        with Image.open(img_path) as im:
            image = np.array(im.convert('L')).astype(np.float32)

        with Image.open(mask_path) as im:
            mask_gt = np.array(im.convert('L'))
        mask_3class = self._to_3class_mask(mask_gt, row['LABEL'])

        if self.transform:
            aug = self.transform(image=image, mask=mask_3class)
            image = aug['image']
            mask  = aug['mask'].long()
        else:
            image = torch.from_numpy(image).unsqueeze(0)
            mask  = torch.from_numpy(mask_3class).long()

        return image, mask

# ----------------- MASKED B vs M CLASSIFIER DATASET -----------------
class MaskedBMClassifierDataset(Dataset):
    """Benign-vs-malignant classification on lesion-masked ultrasound images.

    Only rows with LABEL 0 (benign) or 1 (malignant) are kept. Each sample is
    the grayscale image multiplied by its ground-truth lesion mask (pixels with
    mask value > 0 are kept), replicated to 3 channels for an ImageNet-style
    backbone. Items are ``(image, label)`` with label 0=B, 1=M.

    NOTE(review): ``use_gt_mask`` is stored but not consulted — samples always
    use the ground-truth mask. Masking with predicted segmentations happens in
    ``Pipeline2Inference``, not here.
    """
    def __init__(self, df, img_dir, transform=None, use_gt_mask=True):
        # only B/M cases
        self.df = df[df['LABEL'].isin([0, 1])].reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.use_gt_mask = use_gt_mask

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _masked_rgb(image, mask):
        """Zero the non-lesion pixels (mask <= 0) and stack the result to (H, W, 3) float32."""
        lesion = np.asarray(mask) > 0
        masked = np.asarray(image, dtype=np.float32) * lesion.astype(np.float32)
        return np.stack([masked] * 3, axis=2)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        label = int(row['LABEL'])  # 0=B, 1=M

        img_path  = os.path.join(self.img_dir, row['US'])
        mask_path = os.path.join(self.img_dir, row['MASK'])

        with Image.open(img_path) as im:
            image = np.array(im.convert('L')).astype(np.float32)
        with Image.open(mask_path) as im:
            mask = np.array(im.convert('L'))

        masked_image = self._masked_rgb(image, mask)

        if self.transform:
            masked_image = self.transform(image=masked_image)['image']
        else:
            # Without a transform still return a CHW float tensor (what ToTensorV2
            # yields); an HWC numpy array would collate to (B, H, W, 3).
            masked_image = torch.from_numpy(
                np.ascontiguousarray(masked_image.transpose(2, 0, 1))
            )

        return masked_image, label

# ----------------- TRANSFORMS -----------------
def get_seg_transforms(phase='train'):
    """Albumentations pipeline for single-channel segmentation inputs.

    Every phase resizes to ``cfg.img_size`` square, rescales with
    ``Normalize(mean=0, std=1)`` (i.e. divides by the default max pixel value)
    and converts to a tensor. ``phase == 'train'`` adds flips, 90-degree
    rotations and a small shift/scale/rotate; spatial ops are applied to the
    mask as well when one is passed.

    NOTE: this redefines the ``get_seg_transforms`` name that ``main()`` in the
    binary-segmentation cell above also calls.
    """
    steps = [A.Resize(cfg.img_size, cfg.img_size)]
    if phase == 'train':
        steps += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.3),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.4),
        ]
    steps += [A.Normalize(mean=[0.0], std=[1.0]), ToTensorV2()]
    return A.Compose(steps)

def get_cls_transforms(phase='train'):
    """Albumentations pipeline for the 3-channel masked-lesion classifier input.

    Every phase resizes to ``cfg.img_size`` square, applies ImageNet mean/std
    normalisation and converts to a tensor; ``phase == 'train'`` adds flips and
    90-degree rotations before normalising.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [A.Resize(cfg.img_size, cfg.img_size)]
    if phase == 'train':
        steps += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.RandomRotate90(p=0.3),
        ]
    steps += [A.Normalize(mean=imagenet_mean, std=imagenet_std), ToTensorV2()]
    return A.Compose(steps)

# ----------------- 3-CLASS SEGMENTATION MODEL -----------------
class NBM3ClassUNet(nn.Module):
    """segmentation_models_pytorch U-Net (ResNet-34 encoder, ImageNet weights).

    Takes ``n_channels``-channel images and returns raw per-pixel logits for
    ``n_classes`` classes (``activation=None``, so no softmax is applied).
    """

    def __init__(self, n_channels=1, n_classes=3):
        super(NBM3ClassUNet, self).__init__()
        unet_kwargs = dict(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=n_channels,
            classes=n_classes,
            activation=None,
        )
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, x):
        logits = self.model(x)
        return logits

# ----------------- B vs M CLASSIFIER -----------------
class BMClassifier(nn.Module):
    """Benign-vs-malignant image classifier.

    A timm ``tf_efficientnetv2_s`` backbone (ImageNet-pretrained; ``num_classes=0``
    removes its classification head so it returns pooled features) followed by
    a small dropout/MLP head producing 2 logits (index 0 = benign, 1 = malignant).
    """

    def __init__(self):
        super(BMClassifier, self).__init__()
        self.backbone = timm.create_model('tf_efficientnetv2_s', pretrained=True, num_classes=0)
        feat_dim = self.backbone.num_features
        head_layers = [
            nn.Dropout(0.4),
            nn.Linear(feat_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 2),
        ]
        # The Sequential indices define the checkpoint keys (classifier.1, classifier.4):
        # keep this layer order so saved state_dicts still load.
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.classifier(self.backbone(x))

# ----------------- LOSSES & METRICS -----------------
def multiclass_dice_loss(logits, targets, smooth=1e-6):
    """Soft multi-class Dice loss.

    For each class ``c`` the softmax probability map is compared with the
    one-vs-rest target ``targets == c``; the per-sample Dice is averaged over
    the batch, turned into ``1 - dice`` and finally averaged over classes.

    Args:
        logits: raw scores with the class dimension at dim 1 (4-D, e.g. (B, C, H, W)).
        targets: integer class map matching ``logits`` without the class dim.
        smooth: additive smoothing in numerator and denominator.

    Returns:
        0-dim tensor (differentiable w.r.t. ``logits``).
    """
    n_cls = logits.shape[1]
    soft = torch.softmax(logits, dim=1)

    total = 0
    for cls_idx in range(n_cls):
        p = soft[:, cls_idx:cls_idx + 1]                  # keep channel dim: (B, 1, H, W)
        t = (targets == cls_idx).float().unsqueeze(1)
        inter = (p * t).sum(dim=(2, 3))
        denom = p.sum(dim=(2, 3)) + t.sum(dim=(2, 3))
        total = total + (1 - ((2 * inter + smooth) / (denom + smooth)).mean())

    return total / n_cls

def mean_dice_3class(logits, targets, threshold=0.5, eps=1e-7):
    """Mean hard Dice over classes for a multi-class segmentation batch.

    A pixel counts as predicted class ``c`` when its softmax probability for
    ``c`` exceeds ``threshold``. Per-sample Dice is computed over the two
    spatial dims, averaged over the batch, then averaged over classes. A class
    absent from both prediction and target scores exactly 1 (``eps/eps``).

    Args:
        logits: (B, C, H, W) raw scores.
        targets: (B, H, W) integer class map.
        threshold: probability threshold for a positive pixel (default 0.5).
        eps: smoothing term (default 1e-7).

    Returns:
        float.
    """
    probs = torch.softmax(logits, dim=1)
    # probs[:, c] and targets are (B, H, W): reduce over the *last two* dims.
    # The previous dim=(2, 3) addressed a non-existent 4th dim and raised IndexError.
    spatial = (-2, -1)
    dice_scores = []

    for c in range(logits.shape[1]):
        pred_c = (probs[:, c] > threshold).float()
        target_c = (targets == c).float()
        intersection = (pred_c * target_c).sum(dim=spatial)
        union = pred_c.sum(dim=spatial) + target_c.sum(dim=spatial)
        dice_c = (2 * intersection + eps) / (union + eps)
        dice_scores.append(dice_c.mean())

    return torch.stack(dice_scores).mean().item()

# ----------------- PIPELINE 2 INFERENCE -----------------
class Pipeline2Inference:
    """Chain the 3-class segmentation model and the B-vs-M classifier.

    Returned image-level labels follow the metadata convention:
    0 = benign, 1 = malignant, 2 = normal.

    NOTE(review): the segmentation model was trained on 0-1 scaled grayscale
    (``get_seg_transforms``) while the classifier was trained on
    ImageNet-normalised images that were masked *before* normalisation
    (``get_cls_transforms``). The same tensor feeds both models here, so
    confirm the caller's preprocessing matches what each model expects.
    """

    def __init__(self, seg_model_path, cls_model_path, device):
        """Load both checkpoints (dicts with a 'model_state_dict' entry) onto ``device`` in eval mode."""
        self.device = device
        self.seg_model = NBM3ClassUNet().to(device)
        self.seg_model.load_state_dict(torch.load(seg_model_path, map_location=device)['model_state_dict'])
        self.seg_model.eval()

        self.cls_model = BMClassifier().to(device)
        self.cls_model.load_state_dict(torch.load(cls_model_path, map_location=device)['model_state_dict'])
        self.cls_model.eval()

    def predict_image_label(self, image_tensor, normal_ratio=0.1, lesion_threshold=0.3):
        """Predict the image-level label of a single image.

        Args:
            image_tensor: one image as (C, H, W), or a batch of exactly one
                (1, C, H, W). Channel 0 goes to the segmentation model; the
                whole tensor is masked and fed to the classifier (a 1-channel
                image is replicated to the 3 channels the classifier expects).
            normal_ratio: the image is called normal when the summed benign +
                malignant probability mass is below ``normal_ratio`` times the
                background mass.
            lesion_threshold: pixels whose benign + malignant probability
                exceeds this are kept in the lesion mask.

        Returns:
            int: 0 = benign, 1 = malignant, 2 = normal.

        Raises:
            ValueError: if ``image_tensor`` is not (C, H, W) or (1, C, H, W).
        """
        with torch.no_grad():
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)
            if image_tensor.dim() != 4 or image_tensor.shape[0] != 1:
                raise ValueError(
                    f"expected (C, H, W) or (1, C, H, W), got {tuple(image_tensor.shape)}"
                )
            image_tensor = image_tensor.to(self.device)

            # Step 1: 3-class segmentation -> per-pixel probabilities (3, H, W)
            seg_logits = self.seg_model(image_tensor[:, :1])
            seg_probs = torch.softmax(seg_logits, dim=1)[0]

            bg_mass = seg_probs[0].sum()
            lesion_mass = seg_probs[1].sum() + seg_probs[2].sum()
            if lesion_mass < normal_ratio * bg_mass:  # mostly background
                return 2  # NORMAL

            # Step 2: keep only likely-lesion pixels
            lesion_mask = (seg_probs[1] + seg_probs[2] > lesion_threshold).float()
            masked_img = image_tensor[0] * lesion_mask.unsqueeze(0)
            if masked_img.shape[0] == 1:
                masked_img = masked_img.expand(3, -1, -1)

            # Step 3: B vs M classification (argmax already yields 0=B / 1=M)
            cls_logits = self.cls_model(masked_img.unsqueeze(0))
            return int(cls_logits.argmax(dim=1).item())

# ----------------- MAIN TRAINING -----------------
def train_3class_seg():
    """Train the 3-class (background / benign / malignant) segmentation U-Net.

    Reads the train/val metadata CSVs from ``cfg``, trains ``NBM3ClassUNet``
    with pixel-wise cross-entropy for ``cfg.seg_epochs`` epochs (AdamW + cosine
    annealing) and keeps the checkpoint with the highest validation
    ``mean_dice_3class``. Per-epoch curves are stored in the checkpoint and in
    ``seg_3class_training_history.csv`` under ``cfg.output_dir``.

    Returns:
        str: path of the best checkpoint (``best_3class_seg.pth``). A
        checkpoint is always written after the first epoch, so the returned
        path exists whenever ``cfg.seg_epochs >= 1``.
    """
    print("=== Training 3-class N/B/M Segmentation ===")

    train_df = pd.read_csv(cfg.train_meta)
    val_df   = pd.read_csv(cfg.val_meta)

    train_ds = NBM3ClassSegDataset(train_df, cfg.img_dir, get_seg_transforms('train'))
    val_ds   = NBM3ClassSegDataset(val_df, cfg.img_dir, get_seg_transforms('val'))

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers)

    model = NBM3ClassUNet().to(cfg.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.seg_epochs)

    best_path = os.path.join(cfg.output_dir, 'best_3class_seg.pth')
    # Start below any achievable Dice so the first epoch always saves; with 0 a
    # run that never beats 0.0 would return a path that was never written.
    best_dice = -1.0
    history = {'train_loss': [], 'val_loss': [], 'val_dice': []}

    for epoch in range(cfg.seg_epochs):
        model.train()
        train_loss = 0.0
        for imgs, masks in tqdm(train_loader, desc='Train'):
            imgs, masks = imgs.to(cfg.device), masks.to(cfg.device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # validation
        model.eval()
        val_loss, val_dice_total = 0.0, 0.0
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc='Val'):
                imgs, masks = imgs.to(cfg.device), masks.to(cfg.device)
                logits = model(imgs)
                val_loss += criterion(logits, masks).item()
                val_dice_total += mean_dice_3class(logits, masks)

        train_loss_avg = train_loss / len(train_loader)
        val_loss_avg = val_loss / len(val_loader)
        val_dice = val_dice_total / len(val_loader)
        scheduler.step()

        history['train_loss'].append(train_loss_avg)
        history['val_loss'].append(val_loss_avg)
        history['val_dice'].append(val_dice)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss_avg:.4f}, "
              f"Val Loss: {val_loss_avg:.4f}, Val Dice: {val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save({'model_state_dict': model.state_dict(),
                        'best_dice': best_dice,
                        'epoch': epoch,
                        'history': history},
                       best_path)

    pd.DataFrame(history).to_csv(
        os.path.join(cfg.output_dir, 'seg_3class_training_history.csv'), index=False
    )
    return best_path

def train_bm_classifier(seg_model_path):
    """Train the benign-vs-malignant classifier on ground-truth-masked lesions.

    Uses ``MaskedBMClassifierDataset`` (B/M rows only) and ``BMClassifier`` with
    cross-entropy, AdamW and cosine annealing for ``cfg.cls_epochs`` epochs,
    keeping the checkpoint with the highest validation macro F1.

    Args:
        seg_model_path: segmentation checkpoint path. Currently unused, because
            training masks come from the ground-truth masks.

    Returns:
        str: path of the best checkpoint (``best_bm_classifier.pth``). A
        checkpoint is always written after the first epoch.
    """
    print("=== Training B vs M masked classifier ===")

    train_df = pd.read_csv(cfg.train_meta)
    val_df   = pd.read_csv(cfg.val_meta)

    train_ds = MaskedBMClassifierDataset(train_df, cfg.img_dir, get_cls_transforms('train'))
    val_ds   = MaskedBMClassifierDataset(val_df, cfg.img_dir, get_cls_transforms('val'))

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers)

    model = BMClassifier().to(cfg.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.cls_epochs)

    best_path = os.path.join(cfg.output_dir, 'best_bm_classifier.pth')
    # Below any achievable F1 so the first epoch always writes a checkpoint.
    best_f1 = -1.0
    for epoch in range(cfg.cls_epochs):
        model.train()
        train_loss = 0.0
        train_preds, train_labels = [], []
        for imgs, labels in tqdm(train_loader, desc='Train'):
            imgs, labels = imgs.to(cfg.device), labels.to(cfg.device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_preds.extend(logits.argmax(1).cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        # validation
        model.eval()
        val_loss, val_preds, val_labels = 0.0, [], []
        with torch.no_grad():
            for imgs, labels in tqdm(val_loader, desc='Val'):
                imgs, labels = imgs.to(cfg.device), labels.to(cfg.device)
                logits = model(imgs)
                val_loss += criterion(logits, labels).item()
                val_preds.extend(logits.argmax(1).cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        train_f1 = f1_score(train_labels, train_preds, average='macro')
        val_f1 = f1_score(val_labels, val_preds, average='macro')
        scheduler.step()

        print(f"Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train F1: {train_f1:.4f}, "
              f"Val Loss: {val_loss/len(val_loader):.4f}, Val F1: {val_f1:.4f}")

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save({'model_state_dict': model.state_dict(),
                        'best_f1': best_f1,
                        'epoch': epoch},
                       best_path)

    return best_path

# ----------------- EVALUATION & COMPARISON -----------------
def evaluate_pipeline2(seg_model_path, cls_model_path, val_df, img_dir):
    """Run the full Pipeline 2 on the validation set and score image-level labels.

    Args:
        seg_model_path: checkpoint of the 3-class segmentation model.
        cls_model_path: checkpoint of the B-vs-M classifier.
        val_df: validation metadata (US / MASK / LABEL columns).
        img_dir: directory holding the images.

    Returns:
        (predictions, true_labels, balanced_accuracy, macro_f1, confusion_matrix)
        with labels 0=benign, 1=malignant, 2=normal; the confusion matrix rows
        and columns are ordered [0, 1, 2].
    """
    print("=== Evaluating Pipeline 2 on validation set ===")

    infer = Pipeline2Inference(seg_model_path, cls_model_path, cfg.device)

    # NOTE(review): BreastUSDataset / get_transforms come from an earlier cell.
    # Confirm their preprocessing matches what the segmentation model (0-1
    # scaled grayscale) and the classifier (ImageNet-normalised) were trained on.
    val_ds = BreastUSDataset(val_df, img_dir, transform=get_transforms('val'))
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)

    predictions = []
    true_labels = []

    with torch.no_grad():
        for image, true_label in tqdm(val_loader, desc='Pipeline 2 inference'):
            # The loader yields a (1, C, H, W) batch; predict_image_label adds
            # its own batch dim, so hand it the single (C, H, W) sample.
            predictions.append(int(infer.predict_image_label(image[0])))
            true_labels.append(int(true_label.item()))

    # Score F1 over the same fixed label set as the confusion matrix.
    class_ids = [0, 1, 2]
    balanced_acc = balanced_accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, labels=class_ids, average='macro', zero_division=0)
    cm = confusion_matrix(true_labels, predictions, labels=class_ids)

    print(f"Pipeline 2 - Balanced Accuracy: {balanced_acc:.4f}")
    print(f"Pipeline 2 - Macro F1: {f1:.4f}")
    print("Pipeline 2 - Confusion Matrix:")
    print(cm)

    return predictions, true_labels, balanced_acc, f1, cm

# ----------------- MAIN EXECUTION -----------------
if __name__ == "__main__":
    print("Pipeline 2: 3-class Segmentation → B vs M Classifier")

    # Stage 1 - pixel-level background / benign / malignant segmentation.
    seg_model_path = train_3class_seg()

    # Stage 2 - benign-vs-malignant classifier on lesion-masked images.
    cls_model_path = train_bm_classifier(seg_model_path)

    # Stage 3 - chain both models and score image-level labels on the validation split.
    val_df = pd.read_csv(cfg.val_meta)
    predictions, true_labels, bal_acc, f1, cm = evaluate_pipeline2(
        seg_model_path, cls_model_path, val_df, cfg.img_dir
    )

    # Persist everything the report needs to compare against Pipeline 1.
    results = dict(
        seg_model_path=seg_model_path,
        cls_model_path=cls_model_path,
        balanced_accuracy=bal_acc,
        macro_f1=f1,
        confusion_matrix=cm.tolist(),
        predictions=predictions,
        true_labels=true_labels,
    )
    results_path = os.path.join(cfg.output_dir, 'pipeline2_results.pth')
    torch.save(results, results_path)

    print("\n✓ Pipeline 2 complete!")
    print(f"Models saved in: {cfg.output_dir}")
    print("Ready for report comparison with Pipeline 1!")



Pipeline 2: 3-class Segmentation → B vs M Classifier
=== Training 3-class N/B/M Segmentation ===


Train:   0%|          | 0/148 [00:00<?, ?it/s]

Val:   0%|          | 0/41 [00:00<?, ?it/s]

IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)