In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# ✅ Dataset Paths
# The Train and Test folders sit side by side under the same "Mass" directory,
# so the shared prefix is spelled out only once.
_MASS_ROOT = r"E:\PFE\Flower code\data original\DATA\Mass"
TRAIN_PATH = _MASS_ROOT + "\\Train"
TEST_PATH = _MASS_ROOT + "\\Test"

# ✅ Collect image paths and labels
def collect_data(root_path):
    """Build a table of image paths and class labels found under ``root_path``.

    ``root_path`` must contain a ``BENIGN`` and a ``MALIGNANT`` sub-folder.
    Every entry whose name does not contain ``"MASK"`` is recorded as an
    image; ``BENIGN`` maps to label 0 and ``MALIGNANT`` to label 1.

    Args:
        root_path: Directory holding the two class sub-folders.

    Returns:
        A ``pandas.DataFrame`` with columns ``image`` (full path) and
        ``label`` (0 or 1), ordered by class and then by sorted file name.

    Raises:
        FileNotFoundError: If a class sub-folder does not exist (raised by
            ``os.listdir``).
    """
    data = []
    for label_name, label_id in [("BENIGN", 0), ("MALIGNANT", 1)]:
        class_path = os.path.join(root_path, label_name)
        # Filter on the name rather than on position: stepping through the
        # sorted listing two entries at a time only works while images and
        # masks alternate perfectly, and silently drops images once an image
        # has no mask or more than one.
        data.extend(
            [os.path.join(class_path, file_name), label_id]
            for file_name in sorted(os.listdir(class_path))
            if "MASK" not in file_name
        )
    return pd.DataFrame(data, columns=["image", "label"])

# Index every image found in the predefined Train and Test folders.
train_df = collect_data(TRAIN_PATH)
test_df = collect_data(TEST_PATH)

# ✅ Merge and Stratified Manual Split
# Both folders are pooled, shuffled with a fixed seed and re-split 80/20 per
# class; train_df is rebound below and test_df is not kept as a held-out set.
# NOTE(review): confirm that discarding the original Train/Test boundary is
# intended.
df = pd.concat([train_df, test_df]).sample(frac=1, random_state=42).reset_index(drop=True)

# Gather the per-class slices and concatenate once, rather than growing a
# frame from an empty DataFrame() on every iteration.
train_parts, valid_parts = [], []
for label in df['label'].unique():
    class_df = df[df['label'] == label]
    split = int(0.8 * len(class_df))  # first 80% of each class goes to training
    train_parts.append(class_df.iloc[:split])
    valid_parts.append(class_df.iloc[split:])

# With no rows there are no labels and therefore no parts; fall back to the
# (empty) pooled frame so pd.concat is never handed an empty list.
train_df = pd.concat(train_parts or [df]).reset_index(drop=True)
valid_df = pd.concat(valid_parts or [df]).reset_index(drop=True)

# ✅ Dataset Class
class MassDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: Table with an ``image`` column (path to an image file) and a
            ``label`` column; rows are addressed by position.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and return it with its label."""
        row = self.df.iloc[idx]
        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory image, so the handle can be released right
        # away instead of lingering until garbage collection.
        with Image.open(row["image"]) as source:
            image = source.convert("RGB")
        label = row["label"]
        # Compare against None so that any supplied callable is applied,
        # regardless of its truthiness.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# ✅ Transformations
# Values shared by the training and validation pipelines, named once so the
# two cannot drift apart.
_INPUT_SIZE = (640, 640)
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Random flip, rotation and brightness/contrast jitter are applied to
# training images only.
_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
]

# Training: resize -> augment -> tensor -> normalise.
train_transform = transforms.Compose(
    [transforms.Resize(_INPUT_SIZE)]
    + _augmentations
    + [transforms.ToTensor(), transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD)]
)

# Validation: the same resize and normalisation, without augmentation.
valid_transform = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# ✅ Loaders
_BATCH_SIZE = 16
train_dataset = MassDataset(train_df, transform=train_transform)
valid_dataset = MassDataset(valid_df, transform=valid_transform)
# Only the training loader shuffles; validation keeps the DataFrame order.
train_loader = DataLoader(train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=_BATCH_SIZE)

# ✅ Model
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained ResNet-50 whose final fully connected layer is swapped for a
# 2-output head.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` — confirm against the installed version.
model = models.resnet50(pretrained=True)
head_inputs = model.fc.in_features
model.fc = nn.Linear(head_inputs, 2)
model.to(device)

# ✅ Training Setup
epochs = 95
criterion = nn.CrossEntropyLoss()
# The optimizer is created after the head swap so it covers the new layer too.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
print("start training")
# ✅ Training Loop
# Each epoch makes one optimisation pass over train_loader, then one
# gradient-free pass over valid_loader, and prints loss, accuracy and recall
# for both.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []
    for batch_images, batch_targets in train_loader:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Accumulate the batch loss and the arg-max class per sample.
        train_loss += batch_loss.item()
        train_preds.extend(logits.argmax(1).cpu().numpy())
        train_labels.extend(batch_targets.cpu().numpy())

    train_acc = accuracy_score(train_labels, train_preds)
    train_recall = recall_score(train_labels, train_preds)

    # ✅ Validation
    model.eval()
    valid_loss = 0
    # valid_preds / valid_labels are rebuilt every epoch; the values from the
    # last epoch remain available after the loop.
    valid_preds, valid_labels = [], []
    with torch.no_grad():
        for batch_images, batch_targets in valid_loader:
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_targets)

            valid_loss += batch_loss.item()
            valid_preds.extend(logits.argmax(1).cpu().numpy())
            valid_labels.extend(batch_targets.cpu().numpy())

    valid_acc = accuracy_score(valid_labels, valid_preds)
    valid_recall = recall_score(valid_labels, valid_preds)

    # Reported losses are the mean of the per-batch losses.
    print(f"Epoch {epoch+1} | Train Loss: {train_loss/len(train_loader):.4f} | Valid Loss: {valid_loss/len(valid_loader):.4f} | "
          f"Train Acc: {train_acc:.4f} | Valid Acc: {valid_acc:.4f} | Train Recall: {train_recall:.4f} | Valid Recall: {valid_recall:.4f}")

# ✅ Save model
# Only the weights (state_dict) as they stand after the last epoch are written.
torch.save(model.state_dict(), "cbis_ddsm_resnet_classifier.pth")

# ✅ Confusion Matrix
# valid_labels / valid_preds still hold the final epoch's validation results.
# labels=[0, 1] pins the matrix to 2x2 so it always lines up with the two tick
# labels, even if one class never appears among the labels or predictions.
class_names = ["Benign", "Malignant"]
cm = confusion_matrix(valid_labels, valid_preds, labels=[0, 1])
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()




start training
