In [1]:
import os
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

seq_file_name = '625_38n18_1_2mm_-161_07_41_19_806'

# --- ŚCIEŻKI ---
NEW_IMAGE_DIR = f'../frames_output/{seq_file_name}/preview_fixed'  # zmień na swój katalog
MODEL_ARC_PATH = '../models/autoencoder_arc_reference.pth'
MODEL_WELD_PATH = '../models/autoencoder_weld_reference.pth'
# /home/dburcon/studia/SeqDataReadingProject/models
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_SIZE = (64, 64)

# --- TRANSFORMACJE ---
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# --- STRUKTURA MODELU ---
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# --- ŁADOWANIE MODELÓW ---
model_arc = ConvAutoencoder().to(DEVICE)
model_arc.load_state_dict(torch.load(MODEL_ARC_PATH))
model_arc.eval()

model_weld = ConvAutoencoder().to(DEVICE)
model_weld.load_state_dict(torch.load(MODEL_WELD_PATH))
model_weld.eval()

# --- DATASET Z ROIs ---
class ThermalDatasetMultiROI(Dataset):
    def __init__(self, image_dir, transform=None):
        self.files = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
        self.transform = transform
        self.roi_arc = (295, 410, 345, 480)
        self.roi_weld = (270, 250, 370, 400)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx])
        arc_crop = img.crop(self.roi_arc)
        weld_crop = img.crop(self.roi_weld)
        if self.transform:
            arc_crop = self.transform(arc_crop)
            weld_crop = self.transform(weld_crop)
        return arc_crop, weld_crop, os.path.basename(self.files[idx])

# --- DANE I PRZETWARZANIE ---
dataset = ThermalDatasetMultiROI(NEW_IMAGE_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
criterion = nn.MSELoss()

errors_arc, errors_weld, filenames = [], [], []

with torch.no_grad():
    for arc, weld, fname in dataloader:
        arc = arc.to(DEVICE)
        weld = weld.to(DEVICE)

        recon_arc = model_arc(arc)
        recon_weld = model_weld(weld)

        err_arc = criterion(recon_arc, arc).item()
        err_weld = criterion(recon_weld, weld).item()

        errors_arc.append(err_arc)
        errors_weld.append(err_weld)
        filenames.append(fname[0])

# --- PROGI I WYKRYWANIE ---
threshold_arc = np.percentile(errors_arc, 95)
threshold_weld = np.percentile(errors_weld, 95)

anomalies_arc = [i for i, e in enumerate(errors_arc) if e > threshold_arc]
anomalies_weld = [i for i, e in enumerate(errors_weld) if e > threshold_weld]

# --- WIZUALIZACJA ---
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(errors_arc, label="Arc Error")
plt.axhline(threshold_arc, color='r', linestyle='--', label='Threshold')
plt.scatter(anomalies_arc, [errors_arc[i] for i in anomalies_arc], color='orange', label='Anomalies')
plt.title("Anomalie łuku")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(errors_weld, label="Weld Error")
plt.axhline(threshold_weld, color='r', linestyle='--', label='Threshold')
plt.scatter(anomalies_weld, [errors_weld[i] for i in anomalies_weld], color='orange', label='Anomalies')
plt.title("Anomalie spoiny")
plt.legend()

plt.tight_layout()
plt.show()

IndexError: index -1 is out of bounds for axis 0 with size 0