In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import glob
from PIL import Image
import numpy as np 
from torchvision.io import read_image
import cv2
from torchvision.utils import save_image
import os
from torch.utils import data
import cv2
import math
import torchvision

In [2]:
class conv_block(nn.Module):
  def __init__(self, in_c, out_c):
      super().__init__()

      self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
      self.bn1 = nn.InstanceNorm2d(out_c,affine=True)

      self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
      self.bn2 = nn.InstanceNorm2d(out_c,affine=True)

      self.relu = nn.ReLU()

  def forward(self, inputs):
      x = self.conv1(inputs)
      x = self.bn1(x)
      x = self.relu(x)

      x = self.conv2(x)
      x = self.bn2(x)
      x = self.relu(x)

      return x

class encoder_block(nn.Module):
  def __init__(self, in_c, out_c):
      super().__init__()

      self.conv = conv_block(in_c, out_c)
      self.pool = nn.MaxPool2d((2, 2))

  def forward(self, inputs):
      x = self.conv(inputs)
      p = self.pool(x)

      return x, p

class decoder_block(nn.Module):
  def __init__(self, in_c, out_c):
      super().__init__()

      self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
      self.conv = conv_block(out_c+out_c, out_c)

  def forward(self, inputs, skip):
      x = self.up(inputs)
      x = torch.cat([x, skip], axis=1)
      x = self.conv(x)

      return x


class build_unet(nn.Module):
  def __init__(self):
      super().__init__()

      self.e1 = encoder_block(3, 64)
      self.e2 = encoder_block(64, 128)
      self.e3 = encoder_block(128, 256)
      self.e4 = encoder_block(256, 512)

      self.b = conv_block(512, 1024)

      self.d1 = decoder_block(1024, 512)   
      self.d2 = decoder_block(512, 256)
      self.d3 = decoder_block(256, 128)
      self.d4 = decoder_block(128, 64)

      self.outputs = nn.Conv2d(64, 3, kernel_size=1, padding=0)

  def forward(self, inputs):

      s1, p1 = self.e1(inputs)
      s2, p2 = self.e2(p1)
      s3, p3 = self.e3(p2)
      s4, p4 = self.e4(p3)

      b = self.b(p4)

      d1 = self.d1(b, s4)
      d2 = self.d2(d1, s3)
      d3 = self.d3(d2, s2)
      d4 = self.d4(d3, s1)

      outputs = self.outputs(d4)

      return outputs


In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

transform = A.Compose(
        [
            A.Rotate(limit=10, p=1.0),
            A.HorizontalFlip(p=1),
         A.RandomCrop(height=256, width=256),
         A.RandomBrightnessContrast(p=0.5),
         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.Normalize(
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
            ToTensorV2(),
        ],
    )

class CreateDataset(data.Dataset):
    def __init__(self):
        self.inputs = sorted(glob.glob('./input/*.jpg'))[:]
        self.targets = sorted(glob.glob('./output/*.jpg'))[:]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self,index: int):

        input_path = self.inputs[index]
        output_path = self.targets[index]
        #print(input_path)
        #print(output_path)
        assert os.path.basename(input_path) == os.path.basename(output_path)

        input = transform(Image.open(input_path)).cuda()
        target = transform(Image.open(output_path)).cuda()

        return input,target

dataset = CreateDataset()
dataloader = data.DataLoader(dataset=dataset,batch_size=1,shuffle=True)

In [5]:
model = build_unet()
model.load_state_dict(torch.load(f='./model126.pth'))
model.cuda()

RuntimeError: ignored

In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.001/10)
criterion = nn.MSELoss()

In [None]:

model.train()

for epoch in range(10) :
    
    totalLoss = 0

    for idx,(inputs,targets) in enumerate(dataloader):

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs,targets)

        loss.backward()

        optimizer.step()

        totalLoss += loss.item()

        print("Epoch: " + str(epoch) + " Idx : " + str(idx))
        print("Loss: {:.4f}".format(loss.item()))

    print(totalLoss)

Epoch: 0 Idx : 0
Loss: 0.0259
Epoch: 0 Idx : 1
Loss: 0.0322
Epoch: 0 Idx : 2
Loss: 0.0154
Epoch: 0 Idx : 3
Loss: 0.0126
Epoch: 0 Idx : 4
Loss: 0.0187
Epoch: 0 Idx : 5
Loss: 0.0704
Epoch: 0 Idx : 6
Loss: 0.0108
Epoch: 0 Idx : 7
Loss: 0.0132
Epoch: 0 Idx : 8
Loss: 0.0163
Epoch: 0 Idx : 9
Loss: 0.0423
Epoch: 0 Idx : 10
Loss: 0.0432
Epoch: 0 Idx : 11
Loss: 0.0104
Epoch: 0 Idx : 12
Loss: 0.0073
Epoch: 0 Idx : 13
Loss: 0.0470
Epoch: 0 Idx : 14
Loss: 0.0089
Epoch: 0 Idx : 15
Loss: 0.0286
Epoch: 0 Idx : 16
Loss: 0.0124
Epoch: 0 Idx : 17
Loss: 0.0164
Epoch: 0 Idx : 18
Loss: 0.0233
Epoch: 0 Idx : 19
Loss: 0.0252
Epoch: 0 Idx : 20
Loss: 0.0056
Epoch: 0 Idx : 21
Loss: 0.0078
Epoch: 0 Idx : 22
Loss: 0.0760
Epoch: 0 Idx : 23
Loss: 0.0141
Epoch: 0 Idx : 24
Loss: 0.0131
Epoch: 0 Idx : 25
Loss: 0.0081
Epoch: 0 Idx : 26
Loss: 0.0037
Epoch: 0 Idx : 27
Loss: 0.0580
Epoch: 0 Idx : 28
Loss: 0.0113
Epoch: 0 Idx : 29
Loss: 0.0060
Epoch: 0 Idx : 30
Loss: 0.0112
Epoch: 0 Idx : 31
Loss: 0.0195
Epoch: 0 Idx : 32


In [None]:
totalLoss

4.684875009115785

In [None]:
torch.save(obj=model.state_dict(),f='model.pth')

In [None]:
model.eval()

t = transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor()
])

with torch.no_grad() :
    test_input = t(Image.open('./72.jpg'))
    test_input = torch.unsqueeze(test_input,dim=0).cuda()
    test_output = model(test_input)
    test_output = torch.squeeze(test_output)
    save_image(test_output,'output.jpg')
    
model.cuda()

build_unet(
  (e1): encoder_block(
    (conv): conv_block(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (relu): ReLU()
    )
    (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (e2): encoder_block(
    (conv): conv_block(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (relu): ReLU()
    )
    (pool): MaxPool2d(k