In [1]:
# main.py

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from custom_dataset import CustomImageDataset
from training_functions import train_denoising_model, validate_model
from torchvision.transforms import Resize
from autoencoders import Autoencoder1
import torch.optim as optim
from tqdm import tqdm

# Set random seed for reproducibility (seeds PyTorch's global RNG)
torch.manual_seed(0)

# Run configuration: plain module-level constants, not OS environment variables
TRAIN_FUNCTION = train_denoising_model  # per-batch training step called inside the epoch loop
MODEL = Autoencoder1  # model class; instantiated with no arguments further down
EPOCHS = 2  # number of full passes over the training loader
DEVICE = "cpu"  # device string used for .to() and passed to the train/validate helpers
NOISE_FACTOR = 0.1  # forwarded to TRAIN_FUNCTION; presumably the noise strength -- confirm in training_functions
DATASET_PATH = "/Users/leo/Programming/autoencoder/data/TextImages/train_cleaned"  # NOTE(review): absolute, machine-specific path

# Preprocessing pipeline: resize, then convert to a tensor
transform = transforms.Compose([
    Resize((256, 256)),  # Resize all images to 256x256
    transforms.ToTensor()
])
# Dataset read from DATASET_PATH with the transform above supplied to it
full_dataset = CustomImageDataset(DATASET_PATH, transform=transform)

# 75/25 train/validation partition. The validation subset takes whatever
# is left after truncation, so the two sizes always sum to the dataset length.
total_samples = len(full_dataset)
train_size = int(0.75 * total_samples)
valid_size = total_samples - train_size
split_lengths = [train_size, valid_size]
train_dataset, valid_dataset = random_split(full_dataset, split_lengths)

# Both loaders use batches of 8; only the training loader shuffles.
loader_options = {"batch_size": 8}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)

# Build the network and the MSE loss, move each to DEVICE, then give the
# network's parameters to an Adam optimizer with learning rate 1e-3.
model = MODEL()
model = model.to(DEVICE)
criterion = nn.MSELoss()
criterion = criterion.to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training and Validation
# One entry per epoch: the mean per-batch training loss and whatever
# validate_model returns for the validation set.
train_losses = []
valid_losses = []

for epoch in range(EPOCHS):

    # Training
    # Re-enter training mode at the start of every epoch. validate_model is
    # defined elsewhere and may switch the model to eval mode (assumption --
    # confirm); without this call, later epochs would then train with
    # dropout/batch-norm layers in inference behaviour.
    model.train()
    train_loss = 0.0
    pbar = tqdm(train_loader, desc="Processing", total=len(train_loader), leave=True)  # tqdm for batches
    for data, _ in pbar:
        loss = TRAIN_FUNCTION(model, data, NOISE_FACTOR, optimizer, criterion, DEVICE)
        # If the training step returns a tensor, reduce it to a Python float
        # so the running sum and the stored history do not hold on to
        # tensors (and any autograd graph attached to them).
        if isinstance(loss, torch.Tensor):
            loss = loss.detach().item()
        train_loss += loss
        pbar.set_postfix({'Batch Train Loss': f"{loss:.4f}"})

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    valid_loss = validate_model(model, valid_loader, device=DEVICE)
    valid_losses.append(valid_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Average Train Loss: {avg_train_loss:.4f}, Average Validation Loss: {valid_loss:.4f}")


Processing: 100%|██████████| 14/14 [00:11<00:00,  1.22it/s, Batch Train Loss=0.0357]


Validation Loss: 0.0520
Epoch [1/2] Average Train Loss: 0.0573, Average Validation Loss: 0.0520


Processing: 100%|██████████| 14/14 [00:11<00:00,  1.25it/s, Batch Train Loss=0.0308]


Validation Loss: 0.0411
Epoch [2/2] Average Train Loss: 0.0295, Average Validation Loss: 0.0411
