<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader



In [None]:
class DoubleConv(nn.Module):
  """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

  Args:
    in_chans: number of channels of the input tensor.
    out_chans: number of channels produced by the second convolution.
    hidden_chans: channels between the two convolutions. When falsy (the
      default ``None``) it becomes ``(in_chans + out_chans) // 2``.
    padding: zero-padding used by each 3x3 convolution (keyword-only).
      The default of 0 keeps the original behaviour: each conv uses stride 1
      and no padding, so H and W each shrink by 2 per conv (4 in total).
      ``padding=1`` preserves the spatial size instead.
  """

  def __init__(self, in_chans, out_chans, hidden_chans=None, *, padding=0):
    # Previously super(DownBlock, self), which raised TypeError because
    # DoubleConv is not a subclass of DownBlock.
    super().__init__()
    if not hidden_chans:
      hidden_chans = (in_chans + out_chans) // 2
    self.conv = nn.Sequential(
        nn.Conv2d(in_chans, hidden_chans, 3, 1, padding=padding),
        nn.BatchNorm2d(hidden_chans),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_chans, out_chans, 3, 1, padding=padding),
        nn.BatchNorm2d(out_chans),
        nn.ReLU(inplace=True),
    )

  def forward(self, x):
    """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
    return self.conv(x)

class DownBlock(nn.Module):
  """Encoder stage: DoubleConv followed by 2x2 max-pooling with stride 2.

  Args:
    in_chans: channels of the input tensor.
    out_chans: channels produced by the DoubleConv, and therefore by this block.
  """

  def __init__(self, in_chans, out_chans):
    super().__init__()
    # Removed `if not hidden_chans:`. It read a local assigned only inside
    # that branch, so every construction raised UnboundLocalError, and the
    # value was never used. DoubleConv applies its own hidden-channel default.
    self.conv = DoubleConv(in_chans, out_chans)
    self.down = nn.MaxPool2d(2, 2)

  def forward(self, x):
    """Return ``down(conv(x))``: the DoubleConv features, max-pooled 2x2.

    NOTE(review): the pre-pool ``conv(x)`` output is not returned, so a
    U-Net built from this block cannot obtain skip-connection features from
    it. Confirm this is intended.
    """
    out = self.conv(x)
    return self.down(out)


class UpBlock(nn.Module):
  """Decoder stage: upsample ``x1``, align it to ``x2``, concatenate, DoubleConv.

  Args:
    in_chans: channel count the DoubleConv expects, i.e. the channels of the
      concatenation ``[x2, up(x1)]``. In the transposed-conv path ``x1``
      must itself carry ``in_chans`` channels, which the upsampler halves.
    out_chans: channels produced by the block.
    bilinear: if True, upsample 2x with bilinear interpolation
      (``align_corners=True``), which leaves the channel count unchanged.
      Otherwise use a 2x2, stride-2 ConvTranspose2d mapping
      ``in_chans -> in_chans // 2`` channels.
  """

  def __init__(self, in_chans, out_chans, bilinear=True):
    super().__init__()
    # The upsampler is registered before the conv, as in the original block.
    self.up = (
        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        if bilinear
        else nn.ConvTranspose2d(in_chans, in_chans // 2, 2, 2)
    )
    self.conv = DoubleConv(in_chans, out_chans)

  def forward(self, x1, x2):
    """Upsample ``x1``, zero-pad it to ``x2``'s H/W, cat ``[x2, x1]``, convolve."""
    upsampled = self.up(x1)
    gap_h = x2.size(2) - upsampled.size(2)
    gap_w = x2.size(3) - upsampled.size(3)
    # F.pad orders the last two dims as (left, right, top, bottom). The gap
    # is split so any odd pixel goes to the right/bottom side.
    pads = [gap_w // 2, gap_w - gap_w // 2, gap_h // 2, gap_h - gap_h // 2]
    upsampled = F.pad(upsampled, pads)
    merged = torch.cat([x2, upsampled], dim=1)
    return self.conv(merged)
