In [1]:
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.utils as vutils
from tensorboardX import SummaryWriter


# Alternative: pass an explicit run directory to keep experiments apart.
#writer = SummaryWriter('runs/exp-1')
# TensorBoard event writer used by train() for the loss curve and image grids.
# Created without arguments, so tensorboardX chooses the log directory itself.
writer = SummaryWriter()

In [2]:
# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Ask cuDNN to benchmark its convolution algorithms; setting the flag is
# harmless when running on the CPU.
torch.backends.cudnn.benchmark = True

# Display the name of the selected device. The CUDA query is guarded because it
# raises on machines without a usable GPU, which would defeat the CPU fallback.
torch.cuda.get_device_name(0) if use_cuda else "cpu"

'NVIDIA GeForce 940MX'

In [3]:
class Block(nn.Module):
  """Two stacked 3x3 convolutions, each followed by a ReLU.

  Args:
    in_cn: number of channels of the input.
    out_cn: number of channels produced by both convolutions.
  """

  def __init__(self, in_cn, out_cn):
    super().__init__()
    # Every block consists of two convolutions with a 3x3 kernel.
    self.conv1 = nn.Conv2d(in_cn, out_cn, kernel_size=3)
    self.conv2 = nn.Conv2d(out_cn, out_cn, kernel_size=3)
    # Activation function applied after each convolution.
    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x``."""
    hidden = self.relu(self.conv1(x))
    return self.relu(self.conv2(hidden))

In [4]:
class Encoder(nn.Module):
  """Contracting path: a chain of Blocks with 2x2 max-pooling between them.

  Args:
    chs: channel sizes of the chain; block ``i`` maps ``chs[i]`` channels to
      ``chs[i+1]`` channels.
  """

  def __init__(self, chs=(3,64,128,256,512,1024)):
    super().__init__()
    # One Block per consecutive pair (chs[i], chs[i+1]).
    self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs) - 1)])
    self.pool = nn.MaxPool2d(2)

  def forward(self, x):
    """Return the list with the output of every block, in block order."""
    ftrs = []
    last = len(self.enc_blocks) - 1
    for idx, block in enumerate(self.enc_blocks):
      # Run the block and keep its output as a feature map.
      x = block(x)
      ftrs.append(x)
      # Halve the resolution only when another block follows. Pooling after
      # the last block would be thrown away, and it raises when the deepest
      # feature map is already smaller than the pooling window.
      if idx < last:
        x = self.pool(x)
    return ftrs

In [5]:
class Decoder(nn.Module):
  """Expanding path: up-convolution, skip concatenation and a Block per stage.

  Args:
    chs: channel sizes of the stages; stage ``i`` goes from ``chs[i]`` to
      ``chs[i+1]`` channels.
  """

  def __init__(self, chs = (1024, 512, 256, 128, 64)):
    super().__init__()
    self.chs = chs
    # Consecutive (input, output) channel pairs, one per decoder stage.
    stages = list(zip(chs[:-1], chs[1:]))
    # Up-convolutions with kernel 2 and stride 2, one per stage.
    self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stages])
    # Blocks applied after the concatenation, one per stage.
    self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in stages])

  def forward(self, x, encoder_features):
    """Decode ``x``, using ``encoder_features[i]`` as the skip input of stage ``i``."""
    for i, (upconv, dec_block) in enumerate(zip(self.upconvs, self.dec_blocks)):
      # Up-convolution of the current stage.
      x = upconv(x)
      # Crop the encoder features to the spatial size of x before joining them.
      skip = self.crop(encoder_features[i], x)
      # Concatenate along the channel dimension and run the stage's Block.
      x = dec_block(torch.cat([x, skip], dim=1))
    return x

  def crop(self, enc_ftrs, x):
    """Center-crop ``enc_ftrs`` to the height and width of ``x``."""
    _, _, height, width = x.shape
    cropper = torchvision.transforms.CenterCrop([height, width])
    return cropper(enc_ftrs)

In [6]:
class UNet(nn.Module):
  """U-Net assembled from an Encoder, a Decoder and a 1x1 convolution head.

  Args:
    enc_chs: channel sizes handed to the Encoder.
    dec_chs: channel sizes handed to the Decoder.
    num_class: number of output channels of the head.
    retain_dim: when true, the output is interpolated to ``out_sz``.
    out_sz: target size used when ``retain_dim`` is set.
  """

  def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
    super().__init__()
    self.encoder = Encoder(enc_chs)
    self.decoder = Decoder(dec_chs)
    # 1x1 convolution mapping the last decoder channels to the output channels.
    self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
    self.retain_dim = retain_dim
    self.out_sz = out_sz

  def forward(self, x):
    """Encode, decode with skip connections, apply the head, optionally resize."""
    # Reverse the encoder features: the last one seeds the decoder, the
    # remaining ones are the skip connections in decoder order.
    reversed_ftrs = self.encoder(x)[::-1]
    out = self.head(self.decoder(reversed_ftrs[0], reversed_ftrs[1:]))
    if not self.retain_dim:
      return out
    return nn.functional.interpolate(out, self.out_sz)

In [7]:
from torch.utils.data import Dataset
from PIL import Image
class MyDataset(Dataset):
    """Dataset of (input image, ground-truth image) pairs read from disk.

    Args:
        img_lst: sequence with the paths of the input images.
        gth_lst: sequence with the paths of the ground-truth images; entry
            ``i`` is returned together with ``img_lst[i]``.
        transform: optional callable applied to each opened PIL image (input
            and ground truth alike). It must return its result eagerly (for
            example a tensor), because the image file is closed right after
            the call. When omitted, torchvision's ``to_tensor`` is used.

    Items are dicts with the keys ``'img'`` and ``'gth'``.
    """

    def __init__(self, img_lst, gth_lst, transform=None):
        self.img_lst = img_lst
        self.gth_lst = gth_lst
        self.transform = transform

    def _load(self, path):
        """Open ``path`` and convert the image with the configured transform."""
        # Raw PIL images cannot be batched by the DataLoader's default collate
        # function and have no ``.to(device)``, so convert before returning.
        # The context manager releases the file handle as soon as we are done.
        with Image.open(path) as img:
            if self.transform is not None:
                return self.transform(img)
            return torchvision.transforms.functional.to_tensor(img)

    def __getitem__(self, index):
        # NOTE(review): batching with the default collate presumably still
        # needs all images to share one size; confirm against the dataset or
        # pass a resizing transform.
        x = self._load(self.img_lst[index])
        y = self._load(self.gth_lst[index])

        return {'img':x, 'gth':y}

    def __len__(self):
        return len(self.img_lst)

In [8]:
# DataLoader settings (unpacked into every DataLoader call) and training length.
# NOTE(review): num_workers > 0 starts worker processes; confirm this works for
# a Dataset class defined inside the notebook on the target platform.
params = dict(
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
# Number of passes over the training set.
max_epochs = 100

In [9]:
# Root folder of the dataset; the inputs are read from 'hd' and the ground
# truth from 'sd'.
PATH = 'C:/Users/ger-m/Desktop/UNI/4t/TFG/dataset'
# os.listdir returns bare file names in arbitrary order, so sort both listings
# to get a stable pairing between the two folders.
# NOTE(review): assumes paired files sort to the same position in both folders;
# confirm against the dataset's naming scheme.
pimg = sorted(os.listdir(PATH + '/hd'))
pgth = sorted(os.listdir(PATH + '/sd'))
# Prefix every name with its folder so the files open from any working directory.
images = [PATH + '/hd/' + x for x in pimg]
ground = [PATH + '/sd/' + y for y in pgth]

# Fail early: with different counts the index-based pairing would be wrong.
if len(images) != len(ground):
    raise ValueError(
        'hd and sd folders hold a different number of files: {0} vs {1}'
        .format(len(images), len(ground)))

# First 75% of the pairs for training, the remainder for testing.
split = int(len(images)*0.75)
imtrain = images[:split]
imtest = images[split:]

gttrain = ground[:split]
gttest = ground[split:]

In [10]:
# Dataset over the complete, unsplit file lists (train and test portions together).
data = MyDataset(images, ground)

In [11]:
# Datasets for the two splits, then one DataLoader per dataset. Both loaders
# receive the same settings from `params`.
training_set = MyDataset(imtrain, gttrain)
validation_set = MyDataset(imtest, gttest)

training_generator = torch.utils.data.DataLoader(dataset=training_set, **params)
validation_generator = torch.utils.data.DataLoader(dataset=validation_set, **params)

In [12]:
# Build the network with its default configuration and move its parameters to
# the selected device. train() moves every batch to `device`, so a model left
# on the CPU would fail with a device mismatch when a GPU is used.
# NOTE(review): with the defaults the output has one channel and is not resized
# (retain_dim=False); confirm this matches the ground-truth tensors fed to the loss.
unet = UNet().to(device)

In [14]:
from torch import optim
# Mean squared error between the network output and the ground truth.
criterion = nn.MSELoss()
# Adam over all network parameters.
optimizer = optim.Adam(
    params=unet.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

In [15]:
def train(epoch, dataloader, model, criterion, optimizer, device):
    """Run one training epoch and log its progress.

    For every batch the function does a forward pass, computes the loss,
    back-propagates and applies one optimizer step. The loss is written under
    'train/loss' at every iteration and printed. Every 100th global iteration
    a stacked grid of inputs, ground truth and outputs is written under
    'train/output'.

    Args:
        epoch: index of the current epoch; combined with the batch index to
            form the global iteration number.
        dataloader: sized iterable yielding dicts with the entries 'img' and
            'gth', both supporting ``.to(device)``.
        model: module to train; called as ``model(images)``.
        criterion: callable ``criterion(output, ground)`` returning the loss.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        device: device the batches are moved to.

    Uses the notebook-level ``writer`` and ``vutils`` objects.
    """
    model.train()
    lensamples = len(dataloader)
    for i_batch, sample_batched in enumerate(dataloader):
        images = sample_batched['img'].to(device)
        ground = sample_batched['gth'].to(device)
        # Global iteration counter across epochs, used as the logging step.
        n_iter = epoch*lensamples + i_batch

        output = model(images)

        if n_iter%100==0:
            # Visualisation only: detach the prediction and disable gradient
            # tracking so the grid building (which normalises in place) never
            # becomes part of the autograd graph.
            with torch.no_grad():
                xi = vutils.make_grid(images, normalize=True, scale_each=True)
                xg = vutils.make_grid(ground, normalize=True, scale_each=True)
                xo = vutils.make_grid(output.detach(), normalize=True, scale_each=True)
                # Stack the three grids along dimension 1.
                x = torch.cat((xi,xg,xo),1)
            writer.add_image('train/output', x, n_iter)

        loss = criterion(output, ground)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both log sinks.
        loss_value = loss.item()
        writer.add_scalar('train/loss', loss_value, n_iter)

        print('Train -> sample/numSamples/epoch: {0}/{1}/{2}, Loss: {3}' \
              .format(i_batch, lensamples, epoch, loss_value))

In [16]:
# Train for max_epochs epochs, one full pass over the training loader each.
for epoch in range(max_epochs):
    train(
        epoch=epoch,
        dataloader=training_generator,
        model=unet,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )

In [22]:
# Display the device selected earlier in the notebook.
device

device(type='cuda', index=0)