In [1]:
#main.py

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import Resize
from tqdm import tqdm

from autoencoders import Autoencoder1, Autoencoder2, Autoencoder3
from custom_dataset import CustomImageDataset
from helper_functions import display_denoising_images
from noise import apply_scanning_artifacts
from training_functions import train_denoising_model, validate_model

# --- Reproducibility -------------------------------------------------------
SEED = 0
torch.manual_seed(SEED)
# Also seed NumPy's global RNG in case helper code (e.g. the noise functions)
# draws from it -- NOTE(review): confirm which RNG apply_scanning_artifacts uses.
np.random.seed(SEED)

# --- Configuration ---------------------------------------------------------
TRAIN_FUNCTION = train_denoising_model
MODEL = Autoencoder3()
EPOCHS = 20
# Pick the best available accelerator instead of hard-coding "mps", which
# fails on machines without Apple silicon; fall back to CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
NOISE_PARAMS = {
    "noise": 0.075,
    "warp": 0.5,
    "speckle": 0.4,
    "streak": 0.4,
    "rotate": 0.1
}
IMAGE_SIZE = (256, 256)   # (height, width) every image is resized to
BATCH_SIZE = 8
TRAIN_FRACTION = 0.75     # remainder goes to validation
N_PREVIEW_IMAGES = 10     # capped at BATCH_SIZE, since one batch is sampled

# --- Data ------------------------------------------------------------------
# Override the dataset location with the TEXT_IMAGES_DIR environment variable
# instead of editing this machine-specific absolute default.
image_dataset_path = os.environ.get(
    "TEXT_IMAGES_DIR",
    "/Users/leo/Programming/autoencoder/data/TextImages/train_cleaned",
)
transform = transforms.Compose([
    Resize(IMAGE_SIZE),  # resize all images to IMAGE_SIZE
    transforms.ToTensor()
])
full_dataset = CustomImageDataset(image_dataset_path, transform=transform)

# Train/validation split. A dedicated seeded generator is used because the
# global RNG has already been consumed by MODEL's weight initialisation above,
# which would otherwise make the split depend on the chosen architecture.
train_size = int(TRAIN_FRACTION * len(full_dataset))
valid_size = len(full_dataset) - train_size
train_dataset, valid_dataset = random_split(
    full_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Grab a fixed set of validation images for visual progress checks.
# Dataset items are unpacked as pairs below, so the loader must yield
# (images, second_item) -- NOTE(review): confirm CustomImageDataset's item layout.
first_n_valid_images, _ = next(iter(valid_loader))
first_n_valid_images = first_n_valid_images[:N_PREVIEW_IMAGES]
# ToTensor yields (C, H, W) per image, so a batch should be 4-D; a rank
# mismatch here is a likely cause of numpy "axes don't match array" errors
# when the images are later transposed for plotting.
assert first_n_valid_images.ndim == 4, (
    f"expected a (N, C, H, W) batch, got shape {tuple(first_n_valid_images.shape)}"
)
# Flattened size of one image (C * H * W).
input_dim = int(np.prod(first_n_valid_images.shape[-3:]))



In [2]:
# Train a denoising autoencoder and track per-epoch train/validation loss.
# A fresh instance of the configured architecture is built on every run, so
# re-executing this cell restarts training rather than continuing it, while
# still honouring the MODEL chosen in the configuration cell.
LEARNING_RATE = 1e-3
N_DISPLAY_IMAGES = 4  # how many validation images to plot each epoch

MODEL = type(MODEL)().to(DEVICE)
optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()  # parameter-free, so no device move is needed

train_losses, valid_losses = [], []

for epoch in range(EPOCHS):
    # Use the configured training function rather than a hard-coded one.
    train_loss = TRAIN_FUNCTION(
        MODEL, train_loader, optimizer, criterion, DEVICE,
        apply_scanning_artifacts, **NOISE_PARAMS
    )
    valid_loss = validate_model(
        MODEL, valid_loader, DEVICE, apply_scanning_artifacts, **NOISE_PARAMS
    )

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")

    # Visual progress check on the fixed validation images.
    # NOTE(review): this call raised "ValueError: axes don't match array",
    # which numpy emits when a transpose's axes tuple length differs from the
    # array's rank -- check how display_denoising_images reshapes/transposes
    # the (N, C, H, W) batch (e.g. squeezed single-channel images).
    display_denoising_images(
        N_DISPLAY_IMAGES,
        first_n_valid_images,
        MODEL,
        DEVICE,
        apply_scanning_artifacts,
        **NOISE_PARAMS
    )

ValueError: axes don't match array