In [None]:
import torch
import os
import shutil
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
from torchvision import transforms
from pathlib import Path
from typing import List, Dict, Any


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture only. load_state_dict below is strict, so every parameter
# and buffer (including the replaced fc head) is overwritten by the checkpoint;
# downloading the ImageNet weights first would be wasted time and network access.
model = resnet18(weights=None)
# 4 classes: 0, 90, 180, 270. The index -> angle order is fixed by training
# (the prediction cell maps indices with CLASS_NAMES).
model.fc = nn.Linear(model.fc.in_features, 4)

# Load saved weights.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch versions that support it, weights_only=True restricts what can be loaded.
checkpoint = torch.load("checkpoints/best_model.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode for BatchNorm (uses stored running statistics)


In [None]:
# Test split of the rotation-classification data; the next cell expects it to hold
# sub-directories whose files get flattened into a single folder.
CROPS_DIR: Path = Path("../data/rotation/classification/test")

In [None]:

ALL_ANGLES_CROPS = Path("../data/rotation/classification/all_angles_from_test")
# Reuse the Path instead of repeating the string literal (equivalent to os.makedirs).
ALL_ANGLES_CROPS.mkdir(parents=True, exist_ok=True)

# Flatten CROPS_DIR/<subdir>/<file> into ALL_ANGLES_CROPS/<file>.
for p in sorted(CROPS_DIR.iterdir()):
    # A stray file directly under CROPS_DIR would make p.iterdir() raise.
    if not p.is_dir():
        continue
    print(p)

    for f in sorted(p.iterdir()):
        dest = ALL_ANGLES_CROPS / f.name
        # Files from different sub-directories are merged by name only; never
        # silently overwrite one that was already moved.
        if dest.exists():
            print(f"⚠️ Skipping {f}: {dest} already exists")
            continue
        print(f.name)
        shutil.move(f, dest)

In [None]:
# Drop the 45° variants: the classifier only has the 0/90/180/270 classes.
# Match on the actual filename ending (glob) rather than a substring anywhere in the
# name, and materialize the list before deleting so the directory scan is not
# mutated while it is being iterated.
for f in sorted(ALL_ANGLES_CROPS.glob("*_45.png")):
    if f.is_file():
        print(f.name)
        f.unlink()

## Move the JPEG crops from `cad_crops` into `test_crops` for inference testing

In [None]:

CAD_CROPS_CLEAN = Path("../data/cad_crops/")
TEST_CROPS = Path("../data/test_crops")
# shutil.move cannot create the destination's parent directory.
TEST_CROPS.mkdir(parents=True, exist_ok=True)

for p in sorted(CAD_CROPS_CLEAN.iterdir()):
    # Compare the real extension case-insensitively; a substring test would also
    # match names like "a.jpg.bak" and miss "A.JPG".
    if not p.is_file() or p.suffix.lower() != ".jpg":
        continue
    dest = TEST_CROPS / p.name
    # Don't silently replace a crop that is already in test_crops.
    if dest.exists():
        print(f"⚠️ Skipping {p}: {dest} already exists")
        continue
    print(p)
    shutil.move(p, dest)

In [None]:
import cv2
from pathlib import Path
import numpy as np
from collections import defaultdict

def rotate_patch(patch: np.ndarray, angle: int) -> np.ndarray:
    """Rotate an image counter-clockwise by ``angle`` degrees, enlarging the canvas to fit.

    Multiples of 90° are done exactly with ``np.rot90``: a pure pixel permutation
    with no interpolation. Any other angle uses an affine warp. Uncovered corners
    are filled with (0, 0, 0, 0) for 4-channel images and white otherwise.

    Args:
        patch: image array as returned by ``cv2.imread`` — ``(H, W)`` or ``(H, W, C)``.
        angle: rotation in degrees; positive is counter-clockwise, matching
            ``cv2.getRotationMatrix2D``.

    Returns:
        A new C-contiguous array containing the rotated image.
    """
    if angle % 90 == 0:
        # Rotating about (w/2, h/2) via warpAffine shifts right-angle rotations by one
        # pixel: one edge becomes a line of border colour and the opposite edge
        # row/column is dropped. np.rot90 is exact; .copy() returns a contiguous array
        # (cv2 functions reject non-contiguous views) that does not alias `patch`.
        return np.rot90(patch, k=int(angle // 90) % 4).copy()

    h, w = patch.shape[:2]
    # Rotate about the centre of the pixel grid, not the continuous image extent.
    cx, cy = (w - 1) / 2, (h - 1) / 2
    M = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
    cos, sin = abs(M[0, 0]), abs(M[0, 1])
    new_w = int(h * sin + w * cos)
    new_h = int(h * cos + w * sin)
    # Translate so the rotated content is centred on the enlarged canvas.
    M[0, 2] += (new_w - 1) / 2 - cx
    M[1, 2] += (new_h - 1) / 2 - cy

    # Grayscale images are 2-D, so check ndim before reading the channel count.
    has_alpha = patch.ndim == 3 and patch.shape[2] == 4
    border = (0, 0, 0, 0) if has_alpha else (255, 255, 255)
    return cv2.warpAffine(
        patch,
        M,
        (new_w, new_h),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=border
    )

# Config for the standardize-and-rotate step.
TEST_CROPS = Path("../data/test_crops")      # working directory of test images
NEW_ANGLES = list(range(90, 360, 90))        # rotations generated per image: 90, 180, 270
ANGLE_SUFFIXES = [0, *NEW_ANGLES]            # every angle suffix expected per base image
VALID_EXTS = {".jpg", ".jpeg", ".png"}       # extensions picked up (compared lower-cased)

def _is_generated_rotation(path: Path) -> bool:
    """Return True for a rotation this cell writes itself: ``<base>_<angle>.png`` with angle in NEW_ANGLES."""
    parts = path.stem.split("_")
    return (
        path.suffix == ".png"
        and len(parts) == 2
        and parts[1].isdecimal()
        and int(parts[1]) in NEW_ANGLES
    )


def _standardize_one(img_path: Path, claimed: set) -> None:
    """Save ``img_path`` as ``<base>_0.png`` in TEST_CROPS, write its rotations, then delete the source.

    ``claimed`` collects the ``_0.png`` paths written from other source files during
    this run. Two sources that map to the same base name are reported instead of
    silently overwriting each other.
    """
    img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        print(f"⚠️ Failed to load {img_path}")
        return

    # NOTE(review): the base is everything before the FIRST underscore, so distinct
    # files like "a_x.jpg" and "a_y.jpg" both map to "a" (reported via `claimed`).
    base_name = img_path.stem.split("_")[0]
    new_name = f"{base_name}_0.png"
    new_path = TEST_CROPS / new_name
    already_standard = img_path == new_path

    if not already_standard:
        if new_path in claimed:
            print(f"⚠️ Skipping {img_path.name}: {new_name} was already written from another file")
            return
        # cv2.imwrite signals failure by returning False; keep the source in that case.
        if not cv2.imwrite(str(new_path), img):
            print(f"⚠️ Failed to write {new_path}; keeping {img_path.name}")
            return
        claimed.add(new_path)
        print(f"✅ Renamed/saved: {new_name}")

    for angle in NEW_ANGLES:
        out_name = f"{base_name}_{angle}.png"
        out_path = TEST_CROPS / out_name
        # A freshly written upright image invalidates older rotations of that base;
        # an already-standardized image only needs its missing rotations.
        if already_standard and out_path.exists():
            continue
        if cv2.imwrite(str(out_path), rotate_patch(img, angle)):
            print(f"🌀 Rotated {angle}° -> {out_name}")
        else:
            print(f"⚠️ Failed to write {out_path}")

    # Remove the non-standard source only after its replacement has been written.
    if not already_standard:
        img_path.unlink()


def _report_completeness() -> None:
    """Print how many ``<base>_<angle>.png`` files TEST_CROPS holds and how many are missing."""
    files_after = list(TEST_CROPS.glob("*.png"))
    final_count = len(files_after)

    base_to_angles = defaultdict(set)
    for f in files_after:
        parts = f.stem.split("_")
        if len(parts) == 2 and parts[1].isdecimal():
            base, angle = parts
            base_to_angles[base].add(int(angle))

    expected_angles = set(ANGLE_SUFFIXES)
    expected_total = len(base_to_angles) * len(expected_angles)
    # Count missing (base, angle) pairs directly. Subtracting the raw PNG count would
    # also credit files that don't follow the naming scheme or carry other angles.
    missing_images = sum(len(expected_angles - angles) for angles in base_to_angles.values())

    print("\n🧮 Final stats:")
    print(f"➡️ Total unique base images: {len(base_to_angles)}")
    print(f"➡️ Expected image count ({len(expected_angles)} per base): {expected_total}")
    print(f"✅ Images after processing: {final_count}")
    print(f"❌ Missing images: {missing_images}")


def standardize_and_rotate():
    """Normalize every image in TEST_CROPS to ``<base>_0.png`` and add its NEW_ANGLES rotations.

    Safe to re-run: rotations written by an earlier run are recognised and left
    alone. Prints the image count before processing and a completeness report
    afterwards.
    """
    files_before = sorted(TEST_CROPS.glob("*"))
    image_files = [f for f in files_before if f.is_file() and f.suffix.lower() in VALID_EXTS]
    print(f"\n📸 Images before processing: {len(image_files)}")

    claimed = set()
    for img_path in image_files:
        # Without this guard a re-run re-saved e.g. x_90.png as x_0.png (overwriting
        # the upright image with rotated pixels) and then deleted x_90.png,
        # destroying the ground-truth labels.
        if _is_generated_rotation(img_path):
            continue
        _standardize_one(img_path, claimed)

    _report_completeness()


if __name__ == "__main__":
    standardize_and_rotate()


In [None]:


# ImageNet per-channel statistics (same values the pretrained ResNet backbone uses).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every PIL image before it is passed to the model.
transform = transforms.Compose(
    [
        transforms.Resize((300, 300)),  # must match the resolution used in training
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)


## Test the inference model on the rotated images in `test_crops`

In [None]:
from pathlib import Path
from PIL import Image
import torch
import cv2
import matplotlib.pyplot as plt

suffix = ".png"


def predict_images():
    """Run the rotation classifier on every PNG in TEST_CROPS and score it against the filename labels.

    The ground-truth angle is read from the ``_<angle>`` suffix of each file stem.
    Misclassifications are logged, and the first ``max_to_print`` of them are also
    plotted. All log lines plus a summary are written to ``prediction_log.txt`` in
    the current working directory.

    Returns:
        str: summary line with the summed angular error and the error count.
    """
    # Output index -> angle in degrees. NOTE(review): this order looks like
    # torchvision ImageFolder's lexicographic class order ("0" < "180" < "270" < "90");
    # confirm against the training dataset's class_to_idx.
    CLASS_NAMES = [0, 180, 270, 90]

    total_angle_off = 0
    count = 0
    false_pred_count = 0
    printed_errors = 0
    max_to_print = 30
    logs = []

    # Sorted so the plotted errors and the log file are reproducible between runs.
    for f in sorted(TEST_CROPS.iterdir()):
        if f.suffix.lower() != suffix:
            continue

        # Parse the label before running the model so unlabeled files cost no inference.
        # Require an explicit "_<angle>" suffix: a stem without "_" (e.g. "12345") used
        # to be parsed whole and scored as a wrong prediction for angle 12345.
        _, sep, angle_token = f.stem.rpartition("_")
        try:
            angle = int(angle_token) if sep else None
        except ValueError:
            angle = None
        if angle is None:
            msg = f"⚠️ Could not extract angle from filename: {f.name}"
            print(msg)
            logs.append(msg)
            continue
        if angle not in CLASS_NAMES:
            msg = f"⚠️ Skipping {f.name}: {angle}° is not one of the model's classes {CLASS_NAMES}"
            print(msg)
            logs.append(msg)
            continue

        # The context manager closes the file handle; convert() returns an independent copy.
        with Image.open(f) as pil_img:
            img = pil_img.convert("RGB")
        input_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
        pred_angle = CLASS_NAMES[output.argmax(dim=1).item()]

        # Shortest distance on the circle, so 0° vs 270° counts as 90° off.
        angle_diff = (pred_angle - angle) % 360
        angle_diff = min(angle_diff, 360 - angle_diff)
        total_angle_off += angle_diff

        if angle != pred_angle:
            false_pred_count += 1
            msg = f"❌ {f.name} | GT: {angle}°, Pred: {pred_angle}°"
            logs.append(msg)

            if printed_errors < max_to_print:
                print(msg)
                fig, ax = plt.subplots()
                ax.imshow(img)
                ax.set_title(f"{f.name} | Predicted: {pred_angle}° | ✗")
                ax.axis("off")
                plt.show()
                plt.close(fig)  # avoid accumulating open figures inside the loop
                printed_errors += 1

        count += 1

    summary = f"🧮 Total angular error: {total_angle_off}°, Incorrect predictions: {false_pred_count}/{count}"
    print(summary)
    logs.append(summary)

    # Save all log lines to the log file.
    with open("prediction_log.txt", "w", encoding="utf-8") as f_out:
        f_out.write("\n".join(logs))

    return summary


if __name__ == "__main__":
    print(predict_images())
