In [1]:
# main.py

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from training_functions import train_denoising_model, validate_model
from autoencoders import Autoencoder1
import torch.optim as optim
from tqdm import tqdm

# Set random seed for reproducibility
torch.manual_seed(0)

# Set environment variables
TRAIN_FUNCTION = train_denoising_model
MODEL = Autoencoder2
EPOCHS = 2
DEVICE = "mps"
NOISE_FACTOR = 0.1

# Load data and set DataLoader
transform = transforms.Compose([transforms.ToTensor()])
full_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)

# Split dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
valid_size = len(full_dataset) - train_size
train_dataset, valid_dataset = random_split(full_dataset, [train_size, valid_size])

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)

# Initialize model, criterion, and optimizer
model = MODEL().to(DEVICE)
criterion = nn.MSELoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and Validation
train_losses = []
valid_losses = []

for epoch in range(EPOCHS):

    # Training
    train_loss = 0.0
    pbar = tqdm(enumerate(train_loader), desc="Processing", total=len(train_loader), leave=True)  # tqdm for batches
    for batch_idx, (data, _) in pbar:
        loss = TRAIN_FUNCTION(model, data, NOISE_FACTOR, optimizer, criterion, DEVICE, should_flatten=False)
        train_loss += loss
        pbar.set_postfix({'Batch Train Loss': f"{loss:.4f}"})

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    valid_loss = validate_model(model, valid_loader, device=DEVICE, should_flatten=False)
    valid_losses.append(valid_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Average Train Loss: {avg_train_loss:.4f}, Average Validation Loss: {valid_loss:.4f}")


Files already downloaded and verified


Processing: 100%|██████████| 313/313 [00:28<00:00, 11.15it/s, Batch Train Loss=0.0052]


Validation Loss: 0.0037
Epoch [1/2] Average Train Loss: 0.0087, Average Validation Loss: 0.0037


Processing: 100%|██████████| 313/313 [00:22<00:00, 13.97it/s, Batch Train Loss=0.0029]


Validation Loss: 0.0025
Epoch [2/2] Average Train Loss: 0.0037, Average Validation Loss: 0.0025
