In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
###############################################################################

import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable

class PartialConv2d(nn.Conv2d):
    """2-D convolution whose output is rescaled by the mask sum under each window.

    Accepts every ``nn.Conv2d`` argument plus two keyword-only extras:

    * ``multi_channel`` (default ``False``) -- when true the mask-updating
      kernel has shape ``(out_channels, in_channels, kH, kW)``; otherwise it
      is a single ``(1, 1, kH, kW)`` kernel of ones.
    * ``return_mask`` (default ``False``) -- when true ``forward`` returns
      ``(output, updated_mask)`` instead of just ``output``.
    """

    def __init__(self, *args, **kwargs):
        # Pull out the extension-only options so nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        if self.multi_channel:
            updater_shape = (self.out_channels, self.in_channels, kernel_h, kernel_w)
        else:
            updater_shape = (1, 1, kernel_h, kernel_w)
        # All-ones kernel used to sum the mask under every sliding window.
        self.weight_maskUpdater = torch.ones(*updater_shape)

        # Number of mask entries one window covers (channels * kH * kW).
        _, win_c, win_h, win_w = self.weight_maskUpdater.shape
        self.slide_winsize = win_c * win_h * win_w

        # Cached between calls so a mask-less forward with an unchanged input
        # shape can reuse the previous mask statistics.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the convolution, rescaling by ``slide_winsize / window mask sum``.

        Args:
            input: 4-D tensor fed to the underlying convolution.
            mask_in: optional mask multiplied element-wise into ``input``;
                when omitted an all-ones mask is assumed.

        Returns:
            ``output`` alone, or ``(output, update_mask)`` when the layer was
            built with ``return_mask=True``.
        """
        assert len(input.shape) == 4
        in_shape = tuple(input.shape)
        has_mask = mask_in is not None

        # Recompute mask statistics whenever a mask is supplied or the input
        # shape differs from the one seen on the previous call.
        if has_mask or self.last_size != in_shape:
            self.last_size = in_shape

            with torch.no_grad():
                # Keep the ones-kernel on the same dtype/device as the input.
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if has_mask:
                    mask = mask_in
                elif self.multi_channel:
                    # no mask provided: synthesise a full-size all-ones mask
                    mask = torch.ones(in_shape[0], in_shape[1], in_shape[2], in_shape[3]).to(input)
                else:
                    # no mask provided: a single-channel all-ones mask suffices
                    mask = torch.ones(1, 1, in_shape[2], in_shape[3]).to(input)

                window_sum = F.conv2d(
                    mask,
                    self.weight_maskUpdater,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=1,
                )

                # for mixed precision training, change 1e-8 to 1e-6
                ratio = self.slide_winsize / (window_sum + 1e-8)
                # Binary "any mask entry under this window" indicator.
                self.update_mask = torch.clamp(window_sum, 0, 1)
                # Zero the ratio wherever the window saw no mask entries at all.
                self.mask_ratio = ratio * self.update_mask

        conv_input = input * mask_in if has_mask else input
        raw_out = super(PartialConv2d, self).forward(conv_input)

        if self.bias is None:
            output = raw_out * self.mask_ratio
        else:
            # Rescale only the convolution response, not the bias, then blank
            # out positions whose window had an empty mask.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = (raw_out - bias_view) * self.mask_ratio + bias_view
            output = output * self.update_mask

        if self.return_mask:
            return output, self.update_mask
        return output

In [None]:
class Net(nn.Module):
    """Encoder/decoder CNN with skip concatenations.

    The encoder is four stages of two unpadded 3x3 convolutions followed by a
    2x2 max-pool (64, 128, 256, 512 channels) and a 1024-channel bottleneck.
    The decoder concatenates each saved encoder activation with the running
    tensor and applies two more 3x3 convolutions per stage; a final 1x1
    convolution maps 64 channels to 3.

    The first layer is declared with a single input channel, so ``forward``
    expects a ``(N, 1, H, W)`` tensor.

    NOTE(review): the layer layout resembles U-Net, but the decoder has no
    upsampling step yet -- see the note in ``forward``.
    """

    def __init__(self):
        super(Net, self).__init__()

        # --- encoder ---
        self.conv1_1 = nn.Conv2d(1,64,3,1)
        self.conv1_2 = nn.Conv2d(64,64,3,1)

        self.conv2_1 = nn.Conv2d(64,128,3,1)
        self.conv2_2 = nn.Conv2d(128,128,3,1)

        self.conv3_1 = nn.Conv2d(128,256,3,1)
        self.conv3_2 = nn.Conv2d(256,256,3,1)

        self.conv4_1 = nn.Conv2d(256,512,3,1)
        self.conv4_2 = nn.Conv2d(512,512,3,1)

        # --- bottleneck ---
        self.conv5_1 = nn.Conv2d(512,1024,3,1)
        self.conv5_2 = nn.Conv2d(1024,1024,3,1)
        self.part2d_1 = PartialConv2d(1024,1024,2,1)

        # --- decoder ---
        self.conv6_1 = nn.Conv2d(1024,512,3,1)
        self.conv6_2 = nn.Conv2d(512,512,3,1)
        self.part2d_2 = PartialConv2d(512,512,2,1)

        self.conv7_1 = nn.Conv2d(512,256,3,1)
        self.conv7_2 = nn.Conv2d(256,256,3,1)
        self.part2d_3 = PartialConv2d(256,256,2,1)

        self.conv8_1 = nn.Conv2d(256,128,3,1)
        self.conv8_2 = nn.Conv2d(128,128,3,1)
        self.part2d_4 = PartialConv2d(128,128,2,1)

        self.conv9_1 = nn.Conv2d(128,64,3,1)
        self.conv9_2 = nn.Conv2d(64,64,3,1)

        # 1x1 projection to the 3 output channels
        self.conv10 = nn.Conv2d(64,3,1,1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on ``x`` and return the 3-channel map."""
        # Encoder: keep each stage's activation (cat_N) for the decoder.
        x = self.conv1_1(x)
        x = F.relu(x)
        x = self.conv1_2(x)
        cat_1 = F.relu(x)
        x = F.max_pool2d(cat_1, 2)

        x = self.conv2_1(x)
        x = F.relu(x)
        x = self.conv2_2(x)
        cat_2 = F.relu(x)
        x = F.max_pool2d(cat_2, 2)

        x = self.conv3_1(x)
        x = F.relu(x)
        x = self.conv3_2(x)
        cat_3 = F.relu(x)
        x = F.max_pool2d(cat_3, 2)

        x = self.conv4_1(x)
        x = F.relu(x)
        x = self.conv4_2(x)
        cat_4 = F.relu(x)
        x = F.max_pool2d(cat_4, 2)

        # Bottleneck
        x = self.conv5_1(x)
        x = F.relu(x)
        x = self.conv5_2(x)
        x = F.relu(x)
        x = self.part2d_1(x)

        # NOTE(review): as written, the concatenations below cannot succeed.
        # ``x`` is spatially smaller than ``cat_4`` (it was pooled by 2 and
        # then passed through unpadded convolutions with no upsampling), and
        # even with matching sizes the result would carry 512 + 1024 channels
        # while conv6_1 is declared with 1024 inputs. The same mismatch holds
        # for the three later stages.
        # TODO: add upsampling / cropping and halve the channels in part2d_*.
        x = torch.cat([cat_4,x], dim=1)
        x = self.conv6_1(x)
        x = F.relu(x)
        x = self.conv6_2(x)
        x = F.relu(x)
        x = self.part2d_2(x)

        x = torch.cat([cat_3,x], dim=1)
        x = self.conv7_1(x)
        x = F.relu(x)
        x = self.conv7_2(x)
        x = F.relu(x)
        x = self.part2d_3(x)

        x = torch.cat([cat_2,x], dim=1)
        x = self.conv8_1(x)
        x = F.relu(x)
        x = self.conv8_2(x)
        x = F.relu(x)
        x = self.part2d_4(x)

        x = torch.cat([cat_1,x], dim=1)
        x = self.conv9_1(x)
        x = F.relu(x)
        x = self.conv9_2(x)
        x = F.relu(x)

        x = self.conv10(x)

        return x
        
# Build the model once for the rest of the notebook and print its layer listing.
my_nn = Net()
print(my_nn)

In [4]:
# Change the kernel's working directory to the image folder and list it.
# NOTE(review): this is a hard-coded, machine-specific path, and later cells
# use paths relative to it (e.g. './train_set') -- confirm before re-running
# on another machine.
%cd /project_stuff/mask_off/img
%ls

D:\project_stuff\mask_off\img
 D 드라이브의 볼륨: 새 볼륨
 볼륨 일련 번호: 5A4F-87F6

 D:\project_stuff\mask_off\img 디렉터리

2021-08-09  오후 02:20    <DIR>          .
2021-08-09  오후 02:20    <DIR>          ..
2016-03-23  오후 07:46    <DIR>          img_align_celeba_png
2021-08-09  오후 02:29    <DIR>          mask_on_png
2021-08-09  오후 02:20    <DIR>          org
               0개 파일                   0 바이트
               5개 디렉터리  519,149,273,088 바이트 남음


In [None]:
# Dataset / DataLoader are classes inside the torch.utils.data package, so
# they must be pulled in with ``from ... import``; ``import pkg.Class as X``
# raises ModuleNotFoundError.
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

# Root folder handed to ImageFolder; resolved against the current working directory.
img_root = './train_set'

# Convert each image to a tensor, then normalise its three channels with
# mean 0.5 and std 0.5.
# NOTE(review): this assumes 3-channel images, while Net.conv1_1 is declared
# with a single input channel -- confirm which one is intended.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.ImageFolder(root=img_root, transform=trans)
classes = trainset.classes


"""
#https://data-panic.tistory.com/21
#https://data-panic.tistory.com/13

class CustomDataset(Dataset):
    def __init__(self, x, train=True):
        pass
           
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img, target = self.data[idx], self.targets[idx]
"""

In [None]:
# Iterate the training set in its stored order, 10 samples per batch,
# with two worker processes doing the loading.
trainloader = DataLoader(
    trainset,
    batch_size=10,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Loss/optimizer setup adapted from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html (temporary placeholder)
import torch.optim as optim

# Mean-squared-error objective, optimised with Adam over every parameter of my_nn.
# (The SGD-style ``momentum=0.9`` setting is not used with Adam here.)
criterion = nn.MSELoss()
optimizer = optim.Adam(
    my_nn.parameters(),
    lr=1e-3,
)

In [None]:
# Number of full passes over the training data.
NUM_EPOCHS = 200
# Report the running loss once every LOG_EVERY mini-batches.
LOG_EVERY = 5

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, targets = data

        optimizer.zero_grad()

        # forward + backward + optimize
        out = my_nn(inputs)
        loss = criterion(out, targets)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:    # print every LOG_EVERY mini-batches
            # Divide by the number of batches actually accumulated so the
            # printed value is the mean loss since the last report.
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')

In [None]:
import matplotlib
import matplotlib.pyplot as plt

