In [11]:
import os
import torch
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch import nn
from torch.optim import AdamW
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from tqdm import tqdm
from collections import Counter
import csv
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# ---------------------CONFIG-----------------------
train_dir = "model"
val_dir = "validare"
batch_size = 16
num_epochs = 10
early_stop_patience = 3
threshold = 0.3  #prag
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#-------------------TRANSFORMARI----------------------------
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

#------------------------DATE----------------------------------
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

targets = train_dataset.targets
class_counts = [targets.count(0), targets.count(1)]
print("Distributie clase train:", class_counts)

weights = [1.0 / class_counts[t] for t in targets]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

#------------------------MODEL--------------------------------
model = models.resnet50(pretrained=True)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)
model.to(device)

class FocalLoss(nn.Module):
    def __init__(self, alpha=2.0, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss
        return loss.mean()

loss_fn = FocalLoss(alpha=2.0, gamma=2.0)
optimizer = AdamW(model.parameters(), lr=1e-4)
#---------------------------TRAIN-------------------------
best_recall = 0
epochs_no_improve = 0

with open("fisier_log.csv", mode="w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Epoch", "Precision", "Recall", "F1"])

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            progress_bar.set_postfix({
                "Loss": f"{running_loss / total:.4f}",
                "Acc": f"{(correct / total) * 100:.2f}%"
            })

        print(f"Epoch {epoch+1}: Train Loss = {running_loss / total:.4f}")

        #---------------------VALIDARE--------------------------
        model.eval()
        y_true, y_pred = [], []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)

                probs = torch.softmax(outputs, dim=1)[:, 1]  # probabilitate malign
                preds = (probs > threshold).long()

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)

        print(f"Epoch {epoch+1} → Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")
        print("Etichete reale:", Counter(y_true))
        print("Predictii:", Counter(y_pred))
        print(classification_report(y_true, y_pred, target_names=["Benign", "Malign"]))
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(5, 4))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Benign", "Malign"], yticklabels=["Benign", "Malign"])
        plt.xlabel("Predictii")
        plt.ylabel("Etichete reale")
        plt.title(f"Matricea de Confuzie - Epoca {epoch+1}")
        os.makedirs("confusion_matrices", exist_ok=True)
        plt.savefig(f"confusion_matrices/conf_matrix_epoch_{epoch+1}.png")
        plt.close()
        writer.writerow([epoch+1, precision, recall, f1])

        # Early stopping
        if recall > best_recall:
            best_recall = recall
            epochs_no_improve = 0
            output_dir = "resnet50_mamografie_model"
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
            print(f"Model salvat in: {output_dir}")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stop_patience:
                print(f"Early stopping activat. Recall maxim: {best_recall:.4f}")
                break


Distributie clase train: [348, 123]


Epoch 1/10: 100%|██████████| 30/30 [01:09<00:00,  2.32s/it, Loss=0.0177, Acc=73.46%]


Epoch 1: Train Loss = 0.0177
Epoch 1 → Precision: 0.2700 | Recall: 0.8710 | F1: 0.4122
Etichete reale: Counter({np.int64(0): 88, np.int64(1): 31})
Predictii: Counter({np.int64(1): 100, np.int64(0): 19})
              precision    recall  f1-score   support

      Benign       0.79      0.17      0.28        88
      Malign       0.27      0.87      0.41        31

    accuracy                           0.35       119
   macro avg       0.53      0.52      0.35       119
weighted avg       0.65      0.35      0.31       119

Model salvat in: resnet50_mamografie_model


Epoch 2/10: 100%|██████████| 30/30 [01:10<00:00,  2.36s/it, Loss=0.0188, Acc=68.15%]


Epoch 2: Train Loss = 0.0188
Epoch 2 → Precision: 0.2609 | Recall: 0.9677 | F1: 0.4110
Etichete reale: Counter({np.int64(0): 88, np.int64(1): 31})
Predictii: Counter({np.int64(1): 115, np.int64(0): 4})
              precision    recall  f1-score   support

      Benign       0.75      0.03      0.07        88
      Malign       0.26      0.97      0.41        31

    accuracy                           0.28       119
   macro avg       0.51      0.50      0.24       119
weighted avg       0.62      0.28      0.16       119

Model salvat in: resnet50_mamografie_model


Epoch 3/10: 100%|██████████| 30/30 [01:32<00:00,  3.08s/it, Loss=0.0175, Acc=71.34%]


Epoch 3: Train Loss = 0.0175
Epoch 3 → Precision: 0.2650 | Recall: 1.0000 | F1: 0.4189
Etichete reale: Counter({np.int64(0): 88, np.int64(1): 31})
Predictii: Counter({np.int64(1): 117, np.int64(0): 2})
              precision    recall  f1-score   support

      Benign       1.00      0.02      0.04        88
      Malign       0.26      1.00      0.42        31

    accuracy                           0.28       119
   macro avg       0.63      0.51      0.23       119
weighted avg       0.81      0.28      0.14       119
Model salvat in: resnet50_mamografie_model


Epoch 4/10:   3%|▎         | 1/30 [00:12<05:48, 12.02s/it, Loss=0.0150, Acc=68.75%]


KeyboardInterrupt: 

In [ ]:
import matplotlib.pyplot as plt
import pandas as pd

def plot_metrice_din_csv(csv_path="fisier_log.csv"):
    if not os.path.exists(csv_path):
        print("CSV-ul nu exista.")
        return

    df = pd.read_csv(csv_path)
    plt.figure(figsize=(10, 6))
    plt.plot(df["Epoch"], df["Precision"], label="Precision", marker="o")
    plt.plot(df["Epoch"], df["Recall"], label="Recall", marker="o")
    plt.plot(df["Epoch"], df["F1"], label="F1-score", marker="o")

    plt.title("Evolutia metricilor pe epoci")
    plt.xlabel("Epoca")
    plt.ylabel("Valoare metrica")
    plt.ylim(0, 1.05)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("evolutie_metrici.png")
    plt.show()

plot_metrice_din_csv()
