In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
from multi_digit_cnn import MultiDigitCNN  # Import the CNN model
from image_preprocessing import preprocess_image  # Import preprocessing function

In [111]:
import pandas as pd
import h5py
import numpy as np

# Empty annotation tables for the two splits: one row per image, holding the
# image file name and its sequence of digit labels.
columns = ["filename", "labels"]
df_train, df_test = pd.DataFrame(columns=columns), pd.DataFrame(columns=columns)

# Locations of the digitStruct.mat annotation files
train_file_path = "data/train/digitStruct.mat"  # Update with actual path
test_file_path = "data/test/digitStruct.mat"

# Open the training annotations read-only and look up the per-image bbox
# entries; the number of entries is the number of training images.
f = h5py.File(train_file_path, "r")
bbox_train_dataset = f.get("digitStruct/bbox")
num_train_images = len(bbox_train_dataset)


def extract_labels(bbox_dataset, img_num):
    """Return the digit labels of one image as a 1-D integer NumPy array.

    Args:
        bbox_dataset: The ``digitStruct/bbox`` dataset of an open HDF5 file.
            Each entry holds an object reference to that image's bbox group.
        img_num (int): Zero-based index of the image inside ``bbox_dataset``.

    Returns:
        np.ndarray: The labels of the image, in stored order.

    References are resolved through ``bbox_dataset.file`` (the file that owns
    the dataset) rather than a module-level handle, so the function gives the
    right answer no matter which file a global variable currently points to.
    """
    # NOTE(review): in the SVHN digitStruct format the digit "0" is presumably
    # stored as label 10 - confirm and remap before training a 10-class model.
    h5_file = bbox_dataset.file
    bbox_ref = bbox_dataset[img_num][0]
    label_ref = h5_file[bbox_ref]["label"]
    num_entries = label_ref.shape[0]

    if num_entries == 1:
        # Single digit: the value is stored inline, not behind a reference.
        return np.array([int(label_ref[()].item())])

    # Several digits: every entry is a reference that must be dereferenced.
    return np.array(
        [int(h5_file[label_ref[i][0]][()].item()) for i in range(num_entries)]
    )

# Collect (filename, labels) for every training image, then build the table in
# one step instead of growing the DataFrame row by row.
try:
    train_rows = [
        (f"{i+1}.png", extract_labels(bbox_train_dataset, i))
        for i in range(num_train_images)
    ]
finally:
    # The training annotation file is no longer needed; release the handle.
    f.close()

df_train = pd.DataFrame(train_rows, columns=columns)
df_train.to_csv("data/train/cleaned_train_labels.csv")

# Same extraction for the test split. `f` is rebound to the test file here, so
# helpers that read the module-level handle see the file that is being parsed.
with h5py.File(test_file_path, 'r') as f:
    bbox_test_dataset = f.get('digitStruct/bbox')
    num_test_images = len(bbox_test_dataset)
    test_rows = [
        (f"{i+1}.png", extract_labels(bbox_test_dataset, i))
        for i in range(num_test_images)
    ]

df_test = pd.DataFrame(test_rows, columns=columns)

# Bug fix: the test CSV must contain the test labels (it used to be written
# from df_train, so the test annotations were never saved).
df_test.to_csv("data/test/cleaned_test_labels.csv")

In [11]:
# Preprocessing pipeline applied to every SVHN image before it reaches the model
transform = transforms.Compose(
    [
        # Collapse the image to a single grayscale channel
        transforms.Grayscale(num_output_channels=1),
        # Fixed 1024x1024 input size (as per the paper)
        transforms.Resize((1024, 1024)),
        # PIL image -> float tensor
        transforms.ToTensor(),
        # Shift/scale with mean 0.5 and std 0.5, i.e. normalize to [-1,1]
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class SVHNDataset(Dataset):
    """Dataset of images paired with fixed-length, padded digit-label tensors."""

    def __init__(self, root_dir, annotations, transform=None, max_digits=18, pad_value=-1):
        """
        Args:
            root_dir (str): Path to the directory containing images.
            annotations (DataFrame): DataFrame with 'filename' and 'labels'.
            transform (callable, optional): Image transformations.
            max_digits (int): Maximum number of digits per image.
            pad_value (int): Value used for padding shorter label sequences.
        """
        self.root_dir = root_dir
        self.annotations = annotations
        self.transform = transform
        self.max_digits = max_digits
        self.pad_value = pad_value

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.annotations)

    def _pad_labels(self, labels):
        """Turn a label sequence into a ``max_digits``-long ``torch.long`` tensor.

        Positions beyond ``len(labels)`` are filled with ``pad_value``.

        Raises:
            ValueError: If ``labels`` holds more than ``max_digits`` entries
                (previously this surfaced as an opaque shape-mismatch error).
        """
        num_labels = len(labels)
        if num_labels > self.max_digits:
            raise ValueError(
                f"Got {num_labels} labels but max_digits is {self.max_digits}"
            )

        label_tensor = torch.full((self.max_digits,), self.pad_value, dtype=torch.long)
        if num_labels:
            label_tensor[:num_labels] = torch.tensor(labels, dtype=torch.long)
        return label_tensor

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its padded label tensor.

        Returns:
            tuple: ``(image, label_tensor)``; ``image`` is the RGB PIL image, or
            the result of ``transform(image)`` when a transform was given.
        """
        row = self.annotations.iloc[idx]

        # Load image; the context manager releases the file handle right away,
        # while convert() returns an independent in-memory copy.
        img_name = os.path.join(self.root_dir, row["filename"])
        with Image.open(img_name) as raw_image:
            image = raw_image.convert("RGB")

        # Convert the multi-digit label to a fixed-length tensor (with padding)
        label_tensor = self._pad_labels(row["labels"])

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, label_tensor



In [23]:
import ast
import re
import pandas as pd
# Image directories for the two splits (change these to your actual image directories)
image_train_dir = "data/train"
image_test_dir = "data/test"

# CSV files holding the filename -> labels mapping for each split
train_annotations_path = "data/train/cleaned_train_labels.csv"
test_annotations_path = "data/test/cleaned_test_labels.csv"

# Read both annotation tables back from disk
df_train = pd.read_csv(train_annotations_path)
df_test = pd.read_csv(test_annotations_path)

def fix_label_format(label):
    """Insert the missing commas into a bracketed, whitespace-separated label string.

    A bracketed run of digits and whitespace such as ``"[1 9]"`` becomes
    ``"[1,9]"``. Values that are not strings are returned unchanged.
    """
    if not isinstance(label, str):
        return label

    def _comma_join(match):
        # Split the bracket contents on whitespace and re-join with commas.
        return "[" + ",".join(match.group(1).split()) + "]"

    return re.sub(r"\[([0-9\s]+)\]", _comma_join, label)

# The CSV stores every label list as text: repair the separators first, then
# evaluate the literal to get real Python lists back.
df_train["labels"] = df_train["labels"].apply(fix_label_format).apply(ast.literal_eval)
df_test["labels"] = df_test["labels"].apply(fix_label_format).apply(ast.literal_eval)

# Wrap each split in a dataset that shares the same preprocessing pipeline
train_dataset = SVHNDataset(
    root_dir=image_train_dir, annotations=df_train, transform=transform
)
test_dataset = SVHNDataset(
    root_dir=image_test_dir, annotations=df_test, transform=transform
)

# Batched loaders: only the training split is shuffled
# NOTE(review): the recorded run of this cell ended with
# "OSError: [Errno 28] No space left on device"; batches of 64 images at
# 1024x1024 across 4 workers may be the cause - confirm.
batch_size = 64
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)

# Pull one batch as a sanity check
images, labels = next(iter(train_loader))

# With the Grayscale + Resize((1024, 1024)) transform the images should be
# (batch_size, 1, 1024, 1024) and the labels (batch_size, max_digits).
print("Image batch shape:", images.shape)
print("Labels batch shape:", labels.shape)


OSError: [Errno 28] No space left on device

In [24]:
# Testin gmodel if it works with dummy tensor
import torch
from multi_digit_cnn import MultiDigitCNN  # Import the CNN model

dummy_input = torch.randn(1, 1, 1024, 1024)
model = MultiDigitCNN()
output = model(dummy_input)
print("Output shape:", output.shape)

Output shape: torch.Size([1, 18, 10])


In [22]:
import torch
from multi_digit_cnn import MultiDigitCNN  # Import the CNN model
from training_multi_digit_cnn import train_model

# Fresh, untrained network for this run
model = MultiDigitCNN()

# Launch training on the loaders built earlier.
# NOTE(review): batch_size=32 is passed here while the loaders were created
# with batch_size=64 - confirm which one train_model actually uses.
train_model(
    model,
    batch_size=32,
    learning_rate=0.0001,
    train_loader=train_loader,
    test_loader=test_loader,
    save_path="digits_model_Feb_16, 2025",
)



RuntimeError: DataLoader worker (pid 24107) is killed by signal: Killed. 