In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from models import *
from utils import *

In [2]:
# Configuration cell: every tunable value for the run lives here, followed by
# the objects built from them (model, loss, optimiser, data loader).

# --- Reproducibility ---
# Seed torch's RNG before the model is built so that weight initialisation
# (and any torch-RNG-driven shuffling) is repeatable across kernel restarts.
seed = 0
torch.manual_seed(seed)

# --- Optimisation hyperparameters ---
lr = 0.0002
epochs = 100
batch_size = 4

# The training loop saves an input/output preview grid every `display_step`
# iterations.
display_step = 20

# --- Model shape (passed positionally to UNet below) ---
in_channels = 3
out_channels = 3
hidden_channels = 64
depth = 4

# --- Device: first CUDA device when one is visible, otherwise CPU ---
n_gpu = torch.cuda.device_count()
device = torch.device("cuda:0" if n_gpu else "cpu")

model = UNet(
    in_channels, out_channels, hidden_channels, depth
    ).to(device)

# Mean-squared error between the network output and its target.
# nn.L1Loss() is a drop-in alternative criterion.
criterion = nn.MSELoss()

optim = torch.optim.Adam(model.parameters(), lr=lr)

# --- Data ---
train_filepath = "../data/unet"
crop_size = 1024
input_size = (256, 256)
# `input_size` is the single source of truth for the resize target; it was
# previously defined but the call repeated the literal instead of using it.
# load_datasets_UNet returns two values; the second is discarded here
# (presumably the validation loader, unused since val_filepath=None -- confirm
# against utils).
loader, _ = load_datasets_UNet(
    train_filepath=train_filepath, val_filepath=None,
    crop_size=crop_size, new_size=input_size, batch_size=batch_size
    )

In [3]:
# Bare expression as the last line of the cell: the notebook displays the
# module's repr, i.e. the nested layer tree of the UNet built in the
# configuration cell.
model

UNet(
  (set_hidden_channels): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
  (contracts): ModuleList(
    (0): UNetContractingBlock(
      (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): LeakyReLU(negative_slope=0.2)
    )
    (1): UNetContractingBlock(
      (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): LeakyReLU(negative_slope=0.2)
    )
    (2): UNetContractingBlock(
      (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): LeakyReLU(negative_slope=0.2)
    )
    (3): UNetContractingBlock(
      (conv): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1

In [4]:
# Training loop: the network is trained to reproduce its own input (the target
# `y` is the input batch itself).
#
# `counter`   -- number of iterations completed, counted across all epochs.
# `mean_loss` -- running mean of the per-batch loss over every iteration so far.
log_step = 10        # print the current batch loss every `log_step` iterations
mean_log_step = 50   # print the running mean loss every `mean_log_step` iterations
preview_dir = "imgs"

# save_image does not create missing parent directories; without this the very
# first preview (iteration 0) fails on a fresh checkout.
os.makedirs(preview_dir, exist_ok=True)

mean_loss = 0
counter = 0

for epoch in range(epochs):
    # Nothing inside the loop switches the model to eval mode, so selecting
    # training mode once per epoch is sufficient. The model is left in training
    # mode afterwards: call model.eval() before running inference.
    model.train()

    for data in loader:
        x = data[0].to(device)
        # Target is the input itself; detach() replaces the discouraged `.data`.
        y = x.detach()

        optim.zero_grad()

        out = model(x)

        loss = criterion(out, y)
        loss.backward()
        optim.step()

        # Read the scalar once: each .item() call copies the value to the host.
        loss_value = loss.item()
        # Incremental update of the mean over the (counter + 1) losses seen so far.
        mean_loss = (mean_loss * counter + loss_value) / (counter + 1)

        if counter % log_step == 0:
            print(f"Epoch {epoch} Iter {counter} LOSS: {loss_value}")
        if counter % mean_log_step == 0:
            print(f"Epoch {epoch} Iter {counter} MEAN LOSS: {mean_loss}")

        if counter % display_step == 0:
            # Preview uses the training-mode forward pass computed above.
            # Inputs and outputs are stacked along the batch dimension and
            # nrow equals the batch size, so the inputs fill the first row of
            # the grid and the corresponding outputs the second.
            in_out_tensor = torch.cat([x.detach(), out.detach()], dim=0)
            grid = make_grid(in_out_tensor.cpu(), nrow=x.shape[0])
            save_image(grid, f"{preview_dir}/{counter}.png")

        counter += 1


Epoch 0 Iter 0 LOSS: 0.5998914837837219
Epoch 0 Iter 0 MEAN LOSS: 0.5998914837837219
Epoch 0 Iter 10 LOSS: 0.23619580268859863
Epoch 0 Iter 20 LOSS: 0.16824403405189514
Epoch 0 Iter 30 LOSS: 0.0832994282245636
Epoch 0 Iter 40 LOSS: 0.04227578267455101
Epoch 0 Iter 50 LOSS: 0.039537377655506134
Epoch 0 Iter 50 MEAN LOSS: 0.17200948283368467
Epoch 0 Iter 60 LOSS: 0.014808445237576962
Epoch 0 Iter 70 LOSS: 0.023995641618967056
Epoch 0 Iter 80 LOSS: 0.010572623461484909
Epoch 0 Iter 90 LOSS: 0.010635352693498135
Epoch 0 Iter 100 LOSS: 0.010946562513709068
Epoch 0 Iter 100 MEAN LOSS: 0.09768987077260669
Epoch 0 Iter 110 LOSS: 0.007963653653860092
Epoch 0 Iter 120 LOSS: 0.019558124244213104
Epoch 0 Iter 130 LOSS: 0.005524545907974243
Epoch 0 Iter 140 LOSS: 0.0064530945383012295
Epoch 0 Iter 150 LOSS: 0.005934461951255798
Epoch 0 Iter 150 MEAN LOSS: 0.06860275077479369
Epoch 0 Iter 160 LOSS: 0.014609086327254772
Epoch 0 Iter 170 LOSS: 0.006573219783604145
Epoch 0 Iter 180 LOSS: 0.003929767291

KeyboardInterrupt: 