In [2]:
# === CNN KLASA 0/1 NA BAZIE ROI ===
import os
import glob
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

# === KONFIGURACJA ===
DATA_DIR = "output_rois/classified"  # folder z dwiema klasami: ok/, defect/
IMAGE_SIZE = (64, 64)
BATCH_SIZE = 16
EPOCHS = 20
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# === TRANSFORMACJA ===
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

# === DATASET Z ETYKIETAMI ===
class ClassifiedROIs(Dataset):
    def __init__(self, root_dir, transform=None):
        self.samples = []
        self.transform = transform
        for label, class_name in enumerate(["ok", "defect"]):
            class_dir = os.path.join(root_dir, class_name)
            for img_path in glob.glob(os.path.join(class_dir, "*.jpg")):
                self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path)
        if self.transform:
            img = self.transform(img)
        # === Diagnostyka obrazu ===
        # print(f"[IMG] {os.path.basename(path)} min: {img.min():.3f}, max: {img.max():.3f}")
        # === Diagnostyka etykiety ===
        # print(f"[LABEL] dtype: {type(label)}, value: {label}")
        label_tensor = torch.tensor(float(label), dtype=torch.float32)
        print(f"[LABEL] tensor: {label_tensor}, dtype: {label_tensor.dtype}")
        return img, label_tensor


# === CNN MODEL ===
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 16 * 16, 64), nn.ReLU(),
            nn.Linear(64, 1)  # bez Sigmoid
        )

    def forward(self, x):
        x = self.cnn(x)
        return self.classifier(x)

# === PRZYGOTOWANIE DANYCH ===
dataset = ClassifiedROIs(DATA_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# === MODEL I TRENOWANIE ===
model = SimpleCNN().to(DEVICE)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    correct, total = 0, 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE).unsqueeze(1)
        outputs = model(imgs)
        print("[OUTPUT LOGITS]", outputs.view(-1).detach().cpu().numpy())


        # === Diagnostyka outputu ===
        # print(f"[OUTPUT] min: {outputs.min().item():.3f}, max: {outputs.max().item():.3f}")

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = (torch.sigmoid(outputs) > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {running_loss:.4f} | Accuracy: {acc:.2f}%")

# === ZAPIS MODELU ===
torch.save(model.state_dict(), "cnn_roi_classifier.pth")
print("✅ Model zapisany jako cnn_roi_classifier.pth")


[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 1.0, dtype: torch.float32
[OUTPUT LOGITS] [-0.11260543 -0.11200909 -0.10693707 -0.11311141 -0.10910647 -0.11271537
 -0.11272354 -0.11275409 -0.11261459 -0.10629951 -0.1116506  -0.10929574
 -0.11259114 -0.11198962 -0.11258961 -0.1142038 ]
[LABEL] tensor: 1.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torch.float32
[LABEL] tensor: 0.0, dtype: torc