In [None]:
import cv2  
from PIL import Image 
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import csv
import random
import torch
import torch.nn as nn
import torchvision.models as models

In [None]:
def preprocess_image(
    image,
    output_size=(28, 28),
    margin=5,
    median_kernel_size=5,
    adaptive_block_size=17,
    adaptive_C=5
):
    # Apply Median Filter to reduce salt-and-pepper noise
    if median_kernel_size % 2 == 0 or median_kernel_size < 1:
        print(f"Warning: median_kernel_size must be a positive odd integer. Using default of 3.")
        median_kernel_size = 3
    gray_filtered = cv2.medianBlur(image, median_kernel_size)

    # Apply adaptive thresholding on the filtered image
    if adaptive_block_size % 2 == 0:
        # print(f"Warning: adaptive_block_size must be odd. Adjusting {adaptive_block_size} to {adaptive_block_size + 1}")
        adaptive_block_size +=1
    
    binary = cv2.adaptiveThreshold(gray_filtered, 255, # Use gray_filtered here
                                    cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY_INV,
                                    adaptive_block_size, adaptive_C)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        print(f"Warning: No contours found in image after filtering and adaptive thresholding. Skipping.")
        return None

    cnt = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(cnt)

    original_h_binary, original_w_binary = binary.shape[:2]

    x_margin = max(0, x - margin)
    y_margin = max(0, y - margin)
    w_margin = min(original_w_binary - x_margin, w + 2 * margin)
    h_margin = min(original_h_binary - y_margin, h + 2 * margin)
    
    if w_margin <= 0 or h_margin <= 0:
        print(f"Warning: Margin calculation resulted in non-positive dimension for image. Using original bounding box.")
        dig = binary[y:y+h, x:x+w]
    else:
        dig = binary[y_margin:y_margin+h_margin, x_margin:x_margin+w_margin]

    if dig.size == 0:
        print(f"Warning: Cropping with margin resulted in empty image for image. Skipping.")
        return None
    
    (dh, dw) = dig.shape[:2]
    size_max = max(dw, dh)
    square = np.zeros((size_max, size_max), dtype=np.uint8)

    x_offset = (size_max - dw) // 2
    y_offset = (size_max - dh) // 2
    square[y_offset:y_offset+dh, x_offset:x_offset+dw] = dig

    final_image = cv2.resize(square, output_size, interpolation=cv2.INTER_AREA)

    return final_image

def preprocess_images(
    input_path,
):
    images = []
    labels = []
    for digit_class in range(10):
        image_paths = list(Path(input_path).rglob(f"{digit_class}/*"))
        for image_path in image_paths:
            gray_img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            if gray_img is None:
                print(f"Warning: Could not read image {image_path}. Skipping.")
                continue
            final_image = preprocess_image(gray_img)
            if final_image is None:
                continue
            images.append(final_image)
            labels.append(digit_class)

    return images, labels


In [None]:
images, labels = preprocess_images('./dataset')

In [None]:
def display_sample_images(images, labels, label):
    plt.figure(figsize=(6, 20))
    j = 0
    for i in range(len(images)):
        if labels[i] != label:
            continue
        plt.subplot(20, 6, j + 1)  # Create a grid with 2 rows and 5 columns
        plt.imshow(images[i], cmap="gray")  # Display the image in grayscale
        plt.axis("off")  # Turn off the axis for cleaner visualization
        j += 1
        if j >= 120:
            break
    
    plt.tight_layout()
    plt.show()
  


In [None]:
display_sample_images(images, labels, label=1)

In [3]:
# --- Configuration ---
MODEL_PATH = 'adapted_resnet_digit_classifier.pth' # Path where you saved your model weights
RESNET_VARIANT = 'resnet34' # Must match the variant you trained (e.g., 'resnet18' or 'resnet34')
NUM_CLASSES = 10


# --- Define Modified ResNet18 (or ResNet34) ---
class AdaptedResNet(nn.Module):
    def __init__(self, num_classes=10, resnet_variant='resnet18'):
        super().__init__()
        if resnet_variant == 'resnet18':
            self.resnet = models.resnet18(weights=None) # No pretrained weights, or use models.ResNet18_Weights.DEFAULT for latest
        elif resnet_variant == 'resnet34':
            self.resnet = models.resnet34(weights=None) # No pretrained weights, or use models.ResNet34_Weights.DEFAULT
        else:
            raise ValueError("resnet_variant must be 'resnet18' or 'resnet34'")

        # 1. Modify the first convolutional layer for 1 input channel and 28x28 images
        # Original: self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Adapted: Use a smaller kernel, stride 1 to preserve dimensions initially
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # 2. Remove or modify the initial MaxPool layer
        # For 28x28, it's often better to remove it or make it identity
        self.resnet.maxpool = nn.Identity() # Effectively removes the max pooling

        # 3. Modify the final fully connected layer for 10 classes
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.resnet(x)

In [None]:


# Assuming 'images_array' is your NumPy array of shape (num_samples, 28, 28)
# Assuming 'labels_array' is your NumPy array of shape (num_samples,)

# Convert NumPy arrays to PyTorch tensors
# Add channel dimension for grayscale: (num_samples, 1, 28, 28)
train_images = torch.tensor(images, dtype=torch.float32).unsqueeze(1)
train_labels = torch.tensor(labels, dtype=torch.long)

# Create a TensorDataset and DataLoader
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(train_images, train_labels)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True) # Adjust batch_size



# Instantiate the model
# model = AdaptedResNet(num_classes=10, resnet_variant='resnet18')
model = AdaptedResNet(num_classes=NUM_CLASSES, resnet_variant=RESNET_VARIANT) # Or resnet34

# Example: Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- Training Loop (Conceptual) ---
num_epochs = 10 # Adjust as needed

print(f"Starting training for {num_epochs} epochs...")
for epoch in range(num_epochs):
    model.train() # Set model to training mode
    running_loss = 0.0
    for i, (batch_images, batch_labels) in enumerate(dataloader):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_images)
        loss = criterion(outputs, batch_labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 10 == 0: # Print statistics every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {epoch_loss:.4f}")

print("Finished Training")

torch.save(model.state_dict(), MODEL_PATH)
print(f"Model state_dict saved to ${MODEL_PATH}")


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


model = AdaptedResNet(num_classes=NUM_CLASSES, resnet_variant=RESNET_VARIANT)
try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) # map_location ensures it loads correctly even if trained on GPU and inferring on CPU
    print(f"Model weights loaded successfully from {MODEL_PATH}")
except FileNotFoundError:
    print(f"ERROR: Model weights file not found at {MODEL_PATH}. Please ensure the path is correct.")
    exit()
except Exception as e:
    print(f"ERROR: Could not load model weights. Reason: {e}")
    exit()

model.to(device)

model.eval()
print("Model set to evaluation mode.")

csv_filename= 'resnet34.csv'
image_paths = list(Path("./testset").rglob("*"))

with open(csv_filename, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["filename", "predicted_label"])
    i = 1
    for image_path in image_paths:
        gray_img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        print(f"Processing image {i}")
        i += 1
        if gray_img is None:
            print(f"Warning: Could not read image {image_path}. Skipping.")
            writer.writerow([image_path.name, random.randint(0, 9)])  
            continue
        new_image_data_processed = preprocess_image(gray_img)
        if new_image_data_processed is None:
            new_image_data_processed = cv2.resize(gray_img, (28, 28), interpolation=cv2.INTER_AREA)

        input_tensor = torch.tensor(new_image_data_processed, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        input_tensor = input_tensor.to(device)

       
        with torch.no_grad():
            output_logits = model(input_tensor)

        probabilities = torch.softmax(output_logits, dim=1)
        confidence, predicted_class_index = torch.max(probabilities, 1)

        predicted_digit = predicted_class_index.item()

        writer.writerow([image_path.name, predicted_digit])

    print(f"Data successfully written to '{csv_filename}'")



Using device: cpu
Model weights loaded successfully from adapted_resnet_digit_classifier.pth
Model set to evaluation mode.
Processing image 1
Processing image 2
Processing image 3
Processing image 4
Processing image 5
Processing image 6
Processing image 7
Processing image 8
Processing image 9
Processing image 10
Processing image 11
Processing image 12
Processing image 13
Processing image 14
Processing image 15
Processing image 16
Processing image 17
Processing image 18
Processing image 19
Processing image 20
Processing image 21
Processing image 22
Processing image 23
Processing image 24
Processing image 25
Processing image 26
Processing image 27
Processing image 28
Processing image 29
Processing image 30
Processing image 31
Processing image 32
Processing image 33
Processing image 34
Processing image 35
Processing image 36
Processing image 37
Processing image 38
Processing image 39
Processing image 40
Processing image 41
Processing image 42
Processing image 43
Processing image 44
Proces