In [1]:
# train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

from dataset import Sentinel2InpaintingDataset
from model import ConvAutoEncoder, UNetAutoEncoder

data_root='D:/s2a.tar/s2a'

# In train.py, use target_size that's divisible by 32:
dataset = Sentinel2InpaintingDataset(
    root_dir=data_root,
    mask_type='random',
    augment=True,
    target_size=(256, 256)  # or (256, 256), (384, 384), etc.
)


 54%|█████▍    | 2526/4640 [04:25<03:42,  9.51it/s]


KeyboardInterrupt: 

In [None]:
def train_autoencoder(
    data_root='D:/s2a',
    model_type='unet',  # 'simple' or 'unet'
    batch_size=4,
    num_epochs=10,
    learning_rate=1e-4,
    latent_dim=256,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='checkpoints'
):
    os.makedirs(save_dir, exist_ok=True)
    
    # Create dataset
    print("Loading dataset...")
    
    
    # Train/val split (90/10)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory= True,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    # Create model
    print(f"Creating {model_type} model...")
    if model_type == 'simple':
        model = ConvAutoEncoder(latent_dim=latent_dim)
    else:
        model = UNetAutoEncoder(latent_dim=latent_dim)
    
    model = model.to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    # Training loop
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        batch_count = 0
        for batch in pbar:
            batch_count += 1
            if(batch_count > 100): break
            original = batch['original'].to(device)
            masked = batch['masked'].to(device)
            mask = batch['mask'].to(device)
            
            # Forward pass
            recon, latent = model(masked)
            print(masked.shape)
            
            # Calculate loss (can weight the masked region more)
            loss = criterion(recon, original)
            
            # Optional: Focus more on masked regions
            # masked_region_loss = criterion(recon * (1 - mask), original * (1 - mask))
            # loss = 0.7 * loss + 0.3 * masked_region_loss
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            batch_count = 0
            for batch in val_loader:
                batch_count += 1
                if(batch_count > 10): break
                original = batch['original'].to(device)
                masked = batch['masked'].to(device)
                
                recon, latent = model(masked)
                loss = criterion(recon, original)
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
        
        print(f'Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # # Save best model
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     torch.save({
        #         'epoch': epoch,
        #         'model_state_dict': model.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict(),
        #         'train_loss': train_loss,
        #         'val_loss': val_loss,
        #     }, os.path.join(save_dir, 'best_model.pth'))
        #     print(f'Saved best model with val_loss = {val_loss:.6f}')
        
        # # Save checkpoint every 10 epochs
        # if (epoch + 1) % 1 == 0:
        #     torch.save({
        #         'epoch': epoch,
        #         'model_state_dict': model.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict(),
        #         'train_loss': train_loss,
        #         'val_loss': val_loss,
        #     }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
    
    # Plot training curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.title('Training Progress')
    plt.savefig(os.path.join(save_dir, 'training_curves.png'))
    plt.close()
    
    return model

if __name__ == '__main__':
    model = train_autoencoder(
        data_root='D:/s2a.tar/s2a',
        model_type='simple',  # Use U-Net for better results
        batch_size=4,       # Adjust based on GPU memory
        num_epochs=10,
        learning_rate=1e-4,
        latent_dim=256
    )
 

Loading dataset...
Creating simple model...
Model parameters: 5,476,236


Epoch 1/10:   2%|▏         | 100/5366 [01:18<1:08:34,  1.28it/s, loss=0.0619]


Epoch 1: Train Loss = 0.001507, Val Loss = 0.000819
Saved best model with val_loss = 0.000819


Epoch 2/10:   2%|▏         | 100/5366 [01:17<1:08:16,  1.29it/s, loss=0.0411]


Epoch 2: Train Loss = 0.000772, Val Loss = 0.000550
Saved best model with val_loss = 0.000550


Epoch 3/10:   2%|▏         | 100/5366 [01:15<1:06:33,  1.32it/s, loss=0.0229]


Epoch 3: Train Loss = 0.000531, Val Loss = 0.000510
Saved best model with val_loss = 0.000510


Epoch 4/10:   2%|▏         | 100/5366 [01:16<1:06:50,  1.31it/s, loss=0.0248]


Epoch 4: Train Loss = 0.000441, Val Loss = 0.000466
Saved best model with val_loss = 0.000466


Epoch 5/10:   2%|▏         | 100/5366 [01:18<1:08:53,  1.27it/s, loss=0.0467]


Epoch 5: Train Loss = 0.000494, Val Loss = 0.000358
Saved best model with val_loss = 0.000358


Epoch 6/10:   2%|▏         | 100/5366 [01:11<1:02:46,  1.40it/s, loss=0.00936]


Epoch 6: Train Loss = 0.000348, Val Loss = 0.000325
Saved best model with val_loss = 0.000325


Epoch 7/10:   2%|▏         | 100/5366 [01:07<59:32,  1.47it/s, loss=0.00972] 


Epoch 7: Train Loss = 0.000338, Val Loss = 0.000281
Saved best model with val_loss = 0.000281


Epoch 8/10:   2%|▏         | 100/5366 [01:10<1:02:12,  1.41it/s, loss=0.0304]


Epoch 8: Train Loss = 0.000261, Val Loss = 0.000238
Saved best model with val_loss = 0.000238


Epoch 9/10:   2%|▏         | 100/5366 [01:08<59:48,  1.47it/s, loss=0.0111]  


Epoch 9: Train Loss = 0.000256, Val Loss = 0.000229
Saved best model with val_loss = 0.000229


Epoch 10/10:   2%|▏         | 100/5366 [01:13<1:04:45,  1.36it/s, loss=0.0245]


Epoch 10: Train Loss = 0.000206, Val Loss = 0.000197
Saved best model with val_loss = 0.000197
