In [None]:
# 1. Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import pydicom
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from collections import Counter
from PIL import UnidentifiedImageError
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from datetime import datetime
import os

In [None]:

# %%
# 2. Paths
# Every dataset location is derived from `base_dir`, which is relative to the
# directory the notebook is run from.
base_dir = "../data/ct_scans"
csv_path = os.path.join(base_dir, "overview.csv")
tiff_dir = os.path.join(base_dir, "tiff_images")
# `dcm_dir` was previously assigned twice ("dicom_dir", then "ctscan_png"); only
# the second value ever took effect, so the dead first assignment is removed.
# NOTE(review): the folder name suggests PNG files, while the DICOM branch of
# the dataset reads from this directory with pydicom - confirm which folder is
# actually intended.
dcm_dir = os.path.join(base_dir, "ctscan_png")

In [None]:
# %%
# 3. Transform
# Preprocessing applied to each PIL image before it reaches the network:
# resize to a fixed square, convert to a float tensor, then shift/scale every
# channel with mean 0.5 and std 0.5 (the single-element tuples broadcast over
# all channels).
_IMAGE_SIZE = (128, 128)
transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


In [None]:
class CTDataset(Dataset):
    """CT images listed in a CSV file, served as ``(image, label)`` pairs.

    Every CSV row is expected to provide a ``tiff_name`` column (file inside
    ``tiff_dir``), a ``dicom_name`` column (file inside ``dcm_dir``) and a
    ``Contrast`` column whose value is looked up in ``label_map``
    (``True -> 1``, ``False -> 0``).

    Args:
        csv_file: Path of the CSV file; read once with ``pandas.read_csv``.
        tiff_dir: Directory containing the files named in ``tiff_name``.
        dcm_dir: Directory containing the files named in ``dicom_name``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        use_tiff: ``True`` loads images with PIL from ``tiff_dir``;
            ``False`` loads them with pydicom from ``dcm_dir``.
    """

    def __init__(self, csv_file, tiff_dir, dcm_dir, transform=None, use_tiff=True):
        self.data = pd.read_csv(csv_file)
        self.tiff_dir = tiff_dir
        self.dcm_dir = dcm_dir
        self.transform = transform
        self.use_tiff = use_tiff   # True -> use TIFF images, False -> use DICOM
        # Maps the value of the "Contrast" column to an integer class index.
        self.label_map = {True: 1, False: 0}

    def __len__(self):
        """Return the number of rows in the CSV file."""
        return len(self.data)

    @staticmethod
    def _to_uint8(pixels):
        """Min-max scale an intensity array to ``uint8`` (0..255).

        DICOM pixel data is commonly stored with more than 8 bits per sample;
        handing such an array straight to an 8-bit RGB conversion does not
        rescale the intensity range.  Arrays that are already ``uint8`` are
        returned unchanged, and a constant array maps to all zeros.
        """
        arr = np.asarray(pixels)
        if arr.dtype == np.uint8 or arr.size == 0:
            return arr.astype(np.uint8, copy=False)
        arr = arr.astype(np.float64)
        lo, hi = arr.min(), arr.max()
        if hi <= lo:
            return np.zeros(arr.shape, dtype=np.uint8)
        return np.round((arr - lo) / (hi - lo) * 255.0).astype(np.uint8)

    def _load_image(self, row):
        """Load the image for one CSV row as an RGB ``PIL.Image``.

        Returns ``None`` (after printing a warning) when a TIFF file cannot be
        identified by PIL.  Errors from the DICOM branch are not caught.
        """
        if self.use_tiff:
            img_path = os.path.join(self.tiff_dir, row["tiff_name"])
            try:
                # The context manager releases the file handle; convert()
                # returns an independent, fully loaded image.
                with Image.open(img_path) as handle:
                    return handle.convert("RGB")
            except UnidentifiedImageError:
                print(f"Warning: Cannot open image {img_path}, skipping.")
                return None
        img_path = os.path.join(self.dcm_dir, row["dicom_name"])
        dcm = pydicom.dcmread(img_path)
        return Image.fromarray(self._to_uint8(dcm.pixel_array)).convert("RGB")

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        If the TIFF for ``idx`` is unreadable, the following rows are tried in
        order (wrapping around) and the first readable one is returned instead.

        Raises:
            IndexError: ``idx`` is out of range (raised by ``DataFrame.iloc``).
            RuntimeError: every row was tried and none could be loaded.
        """
        total = len(self.data)
        current = idx
        # Bounded scan instead of unbounded recursion: at most one attempt per row.
        for _ in range(max(total, 1)):
            # The first lookup uses `idx` as given so an out-of-range index
            # still raises IndexError rather than silently wrapping around.
            row = self.data.iloc[current]
            img = self._load_image(row)
            if img is None:
                current = (current + 1) % total
                continue
            if self.transform:
                img = self.transform(img)
            label = self.label_map[row["Contrast"]]
            return img, label
        raise RuntimeError(
            f"None of the {total} images could be opened (started at index {idx})."
        )

In [None]:
# 5. Dataset & DataLoader
# Build the dataset on its DICOM branch (use_tiff=False) and wrap it in a
# shuffling loader that yields batches of 16.
dataset = CTDataset(
    csv_file=csv_path,
    tiff_dir=tiff_dir,
    dcm_dir=dcm_dir,
    transform=transform,
    use_tiff=False,
)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Sanity check: report the dataset size and peek at one (shuffled) batch.
print(f"Number of images: {len(dataset)}")
print("Sample batch:")
batch_iterator = iter(dataloader)
imgs, labels = next(batch_iterator)
print(imgs.shape, labels)


In [None]:
class SimpleCNN(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages and a two-layer head.

    The flattened feature size is derived from ``input_size`` instead of being
    fixed, so the same class works for other resolutions and channel counts.
    With the defaults the layers are identical to the previous hard-coded
    version (3-channel 128x128 input, ``Linear(32*32*32, 128)``), and the
    ``state_dict`` keys are unchanged.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels of the input images.
        input_size: Spatial size of the input, either one int for square
            images or a ``(height, width)`` pair.

    Raises:
        ValueError: the input is too small to survive the two 2x2 poolings.
    """

    def __init__(self, num_classes=3, in_channels=3, input_size=128):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # Each MaxPool2d(2) halves both spatial dimensions, rounding down.
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for two 2x2 poolings"
            )
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * pooled_h * pooled_w, 128), nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.net(x)

# Network, loss and optimizer shared by the training loop below.
# NOTE(review): the model emits 3 logits while CTDataset.label_map only
# produces the labels 0 and 1 - confirm whether a third class is intended.
model = SimpleCNN(num_classes=3)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# 7. Training Loop
# Each epoch makes one pass over `dataloader`, then prints the mean batch loss
# and the accuracy of the predictions made during that same pass (so the
# accuracy is measured on the training data, not on a held-out set).
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0.0
    all_preds, all_labels = [], []

    for imgs, labels in dataloader:
        # Optimisation step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Bookkeeping for the per-epoch report.
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    mean_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}, Accuracy: {acc:.4f}")

# %%

In [None]:
# Persist a checkpoint holding the model weights, the optimizer state and the
# number of classes the model was built with.
model_path = "../backend/models/simple_cnn_ct.pth"
# torch.save does not create missing parent directories; create the target
# folder first so a finished training run is not lost to a save error.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'num_classes': 3}, model_path)
print(f'Model saved to {model_path}')

In [43]:


def detect_anomaly(img_path, prev_img_path, ssim_threshold=0.85,
                   compare_size=(256, 256), distortion_threshold=0.3,
                   shift_threshold=25):
    """Compare the current scan with the previous one and classify the change.

    Both images are read as grayscale, resized to ``compare_size`` and compared
    with SSIM.  When the SSIM score is below ``ssim_threshold`` a simple
    heuristic picks the anomaly type.

    Args:
        img_path: Path of the current image.
        prev_img_path: Path of the previous image, or ``None`` when there is
            nothing to compare against.
        ssim_threshold: An anomaly is reported when the SSIM score is below
            this value.
        compare_size: Size passed to ``cv2.resize`` for both images before
            comparison.
        distortion_threshold: A mean of the full SSIM map below this value is
            reported as ``"Image Distortion"``.
        shift_threshold: Otherwise, a mean absolute pixel difference above this
            value is reported as ``"Organ Shift"``.

    Returns:
        ``None`` when ``prev_img_path`` is ``None`` or no anomaly is found;
        otherwise one of ``"Image Distortion"``, ``"Organ Shift"`` or
        ``"Instrument Misalignment"``.  If either image cannot be read, the
        string ``"Error: Could not load images"`` is returned (not raised), so
        callers that only test truthiness will treat it as an anomaly.
    """
    if prev_img_path is None:
        return None  # No comparison for the first image

    # Load images (grayscale for simplicity); cv2.imread yields None on failure.
    previous = cv2.imread(prev_img_path, cv2.IMREAD_GRAYSCALE)
    current = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if previous is None or current is None:
        return "Error: Could not load images"

    # SSIM needs both inputs to have the same shape.
    previous = cv2.resize(previous, compare_size)
    current = cv2.resize(current, compare_size)

    score, ssim_map = ssim(previous, current, full=True)
    if not (score < ssim_threshold):
        return None

    # Simple heuristics for the anomaly type, checked in order.
    if np.mean(ssim_map) < distortion_threshold:
        return "Image Distortion"
    if np.mean(cv2.absdiff(previous, current)) > shift_threshold:
        return "Organ Shift"
    return "Instrument Misalignment"


# --- Only Two Images ---
before_img = "../userImage/image1.png"
after_img = "../userImage/image2.png"


def _mtime_label(path):
    """Return the file's last-modification time formatted as HH:MM:SS."""
    return datetime.fromtimestamp(os.path.getmtime(path)).strftime("%H:%M:%S")


# Timestamps are file modification times, which need not match the order in
# which the scans were taken (the "after" file can be older than "before").
before_time = _mtime_label(before_img)
after_time = _mtime_label(after_img)

# Detect anomaly between before -> after (current image first, previous second).
anomaly = detect_anomaly(after_img, before_img)

print("Comparing scans:")
print(f" - Before ({before_time}): {before_img}")
print(f" - After  ({after_time}): {after_img}")

verdict = (
    f"⚠️ Anomaly Detected between scans: {anomaly}"
    if anomaly
    else "✅ No anomaly detected between scans."
)
print(verdict)


Comparing scans:
 - Before (10:45:00): ../userImage/image1.png
 - After  (10:44:40): ../userImage/image2.png
⚠️ Anomaly Detected between scans: Image Distortion
