In [2]:
import os
import torch
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch import nn
from torch.optim import AdamW
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import csv

# ---------------------CONFIG-----------------------
train_dir = "model"
val_dir = "validare"
batch_size = 16
num_epochs = 10
early_stop_patience = 3
threshold = 0.3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------TRANSFORMARI-------------------
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# ------------------------DATE-----------------------
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

targets = train_dataset.targets
class_counts = [targets.count(0), targets.count(1)]
print("Distributie clase train:", class_counts)

weights = [1.0 / class_counts[t] for t in targets]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# ------------------------MODEL----------------------
model = models.resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # îngheață tot

num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # doar fc va fi antrenat
model.fc.requires_grad = True
model.to(device)

# Încarcă model salvat (dacă există)
model_path = "resnet50_mamografie_best_now/model.pt"
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Model pre-antrenat încărcat din: {model_path}")

# -------------------OPTIMIZATOR, LOSS, SCHEDULER-----
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, class_counts[0] / class_counts[1]], dtype=torch.float).to(device))
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

# -----------------------TRAINING--------------------
best_recall = 0
epochs_no_improve = 0

with open("log_finetune.csv", mode="w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Epoch", "Precision", "Recall", "F1", "LR"])

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            progress_bar.set_postfix({
                "Loss": f"{running_loss / total:.4f}",
                "Acc": f"{(correct / total) * 100:.2f}%"
            })

        # ---------------------VALIDARE----------------------
        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                probs = torch.softmax(outputs, dim=1)[:, 1]
                preds = (probs > threshold).long()

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step(recall)

        print(f"\nEpoch {epoch+1} → Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f} | LR: {current_lr}")
        print(classification_report(y_true, y_pred, target_names=["Benign", "Malign"]))

        # Matrice confuzie
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(5, 4))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Benign", "Malign"], yticklabels=["Benign", "Malign"])
        plt.xlabel("Predictii")
        plt.ylabel("Etichete reale")
        plt.title(f"Matricea de Confuzie - Epoca {epoch+1}")
        os.makedirs("confusion_matrices", exist_ok=True)
        plt.savefig(f"confusion_matrices/conf_matrix_epoch_{epoch+1}.png")
        plt.close()

        writer.writerow([epoch+1, precision, recall, f1, current_lr])

        # ------------------EARLY STOPPING + SAVE------------
        if recall > best_recall:
            best_recall = recall
            epochs_no_improve = 0
            output_dir = "resnet50_mamografie_model_best_now"
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
            print(f"Model salvat în: {output_dir}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stop_patience:
                print(f"Early stopping activat. Recall maxim: {best_recall:.4f}")
                break


Distributie clase train: [348, 123]




Model pre-antrenat încărcat din: resnet50_mamografie_best_now/model.pt


Epoch 1/10: 100%|██████████| 30/30 [00:20<00:00,  1.45it/s, Loss=0.0225, Acc=79.62%]



Epoch 1 → Precision: 0.5000 | Recall: 0.5484 | F1: 0.5231 | LR: 0.0001
              precision    recall  f1-score   support

      Benign       0.84      0.81      0.82        88
      Malign       0.50      0.55      0.52        31

    accuracy                           0.74       119
   macro avg       0.67      0.68      0.67       119
weighted avg       0.75      0.74      0.74       119

Model salvat în: resnet50_mamografie_model_best_now


Epoch 2/10: 100%|██████████| 30/30 [00:29<00:00,  1.01it/s, Loss=0.0231, Acc=78.56%]



Epoch 2 → Precision: 0.4419 | Recall: 0.6129 | F1: 0.5135 | LR: 0.0001
              precision    recall  f1-score   support

      Benign       0.84      0.73      0.78        88
      Malign       0.44      0.61      0.51        31

    accuracy                           0.70       119
   macro avg       0.64      0.67      0.65       119
weighted avg       0.74      0.70      0.71       119

Model salvat în: resnet50_mamografie_model_best_now


Epoch 3/10: 100%|██████████| 30/30 [00:28<00:00,  1.06it/s, Loss=0.0189, Acc=83.01%]



Epoch 3 → Precision: 0.4043 | Recall: 0.6129 | F1: 0.4872 | LR: 0.0001
              precision    recall  f1-score   support

      Benign       0.83      0.68      0.75        88
      Malign       0.40      0.61      0.49        31

    accuracy                           0.66       119
   macro avg       0.62      0.65      0.62       119
weighted avg       0.72      0.66      0.68       119


Epoch 4/10: 100%|██████████| 30/30 [00:27<00:00,  1.09it/s, Loss=0.0214, Acc=79.19%]



Epoch 4 → Precision: 0.4524 | Recall: 0.6129 | F1: 0.5205 | LR: 0.0001
              precision    recall  f1-score   support

      Benign       0.84      0.74      0.79        88
      Malign       0.45      0.61      0.52        31

    accuracy                           0.71       119
   macro avg       0.65      0.68      0.65       119
weighted avg       0.74      0.71      0.72       119


Epoch 5/10: 100%|██████████| 30/30 [00:28<00:00,  1.04it/s, Loss=0.0235, Acc=80.68%]



Epoch 5 → Precision: 0.3654 | Recall: 0.6129 | F1: 0.4578 | LR: 0.0001
              precision    recall  f1-score   support

      Benign       0.82      0.62      0.71        88
      Malign       0.37      0.61      0.46        31

    accuracy                           0.62       119
   macro avg       0.59      0.62      0.58       119
weighted avg       0.70      0.62      0.64       119

Early stopping activat. Recall maxim: 0.6129
