In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from conv_autoencoder import conv_autoencoder
from data_class import highres_img_dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import os


In [16]:
# Folder that receives the periodic sample reconstructions from training;
# creating it up front means the save step never hits a missing directory.
output_dir = 'images/training_output'
os.makedirs(name=output_dir, exist_ok=True)  # re-running the cell is harmless

In [17]:
# Training data. The loop below unpacks each batch as (lr_imgs, hr_imgs),
# so the dataset presumably yields (low-res, high-res) pairs — verify in data_class.
img_dataset = highres_img_dataset(image_dir='images/training_set')
# Mini-batches of 16, reshuffled each time the loader is iterated.
dataloader = DataLoader(img_dataset, shuffle=True, batch_size=16)

In [18]:
# Build the convolutional autoencoder and print its module tree;
# the architecture listing is this cell's captured output.
model = conv_autoencoder()
print(str(model))

conv_autoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Mish(inplace=True)
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Mish(inplace=True)
  )
  (decoder): Sequential(
    (0): Upsample(scale_factor=3.0, mode='bilinear')
    (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Mish(inplace=True)
    (3): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Sigmoid()
  )
)


In [19]:
# Adam over all of the autoencoder's parameters, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Reconstruction loss: mean squared error averaged over every element
# (nn.MSELoss default reduction='mean').
criterion = nn.MSELoss()

In [None]:
# Train the autoencoder for `n_epochs`, logging the mean per-sample loss each
# epoch and saving one batch of reconstructions every `sample_every` epochs.
n_epochs = 30
sample_every = 5  # epochs between saved sample reconstructions
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to modifies the module in place, so the optimizer created in the
# previous cell keeps referencing the same parameter objects.
model.to(device)

for epoch in tqdm(range(n_epochs)):
    model.train()
    training_loss = 0.0

    # Exactly one pass over the training data per epoch.
    for lr_imgs, hr_imgs in dataloader:
        # Tensor.to returns a new tensor rather than moving in place, so the
        # result must be re-assigned or the batch stays on the CPU.
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        optimizer.zero_grad()
        outputs = model(lr_imgs)
        # NOTE(review): assumes model output and hr_imgs have identical shapes;
        # confirm the dataset's high-res size matches the decoder's upsampling.
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()

        # Assumes criterion averages over the batch (nn.MSELoss default), so
        # scale by batch size; dividing by the dataset length below then gives
        # the mean per-sample loss even when the last batch is smaller.
        training_loss += loss.item() * lr_imgs.size(0)

    epoch_loss = training_loss / len(dataloader.dataset)
    # tqdm.write prints without corrupting the progress bar, unlike print().
    tqdm.write(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss:.4f}")

    if (epoch + 1) % sample_every == 0:
        model.eval()  # model.train() at the top of the next epoch switches back
        with torch.no_grad():
            sample_lr_imgs, _ = next(iter(dataloader))
            sample_outputs = model(sample_lr_imgs.to(device))
            # Save the output images
            save_image(sample_outputs, os.path.join(output_dir, f'output_epoch_{epoch+1}.jpg'))
        tqdm.write(f"Saved output images for epoch {epoch+1}")
        
    
            
            
        

  0%|          | 0/30 [00:00<?, ?it/s]