In [3]:
import numpy as np
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch import nn
import cv2 as cv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm.auto import tqdm
import imageio
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter 

In [4]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.block=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        return self.block(x)

In [5]:
def copy_and_crop(down_1layer,up_1layer):
    b,ch,h,w=up_1layer.shape
    crop=T.CenterCrop((h,w))(down_1layer)
    return crop

In [6]:
class Unet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Unet,self).__init__()

        self.encoder=nn.ModuleList([
            DoubleConv(in_channels,64),
            DoubleConv(64,128),
            DoubleConv(128,256),
            DoubleConv(256,512)
        ])
        self.pool=nn.MaxPool2d(kernel_size=2,stride=2)
        self.bottle_neck=DoubleConv(512,1024)

        self.up_samples=nn.ModuleList([
            nn.ConvTranspose2d(1024,512,kernel_size=2,stride=2),
            nn.ConvTranspose2d(512,256,kernel_size=2,stride=2),
            nn.ConvTranspose2d(256,128,kernel_size=2,stride=2),
            nn.ConvTranspose2d(128,64,kernel_size=2,stride=2)
        ])
        self.decoder=nn.ModuleList([
            DoubleConv(1024,512),
            DoubleConv(512,256),
            DoubleConv(256,128),
            DoubleConv(128,64)
        ])
        self.final_1layer=nn.Conv2d(64,out_channels,kernel_size=1,stride=1)

    def forward(self,x):
        skip_connections=[]

        for layer in self.encoder:
            x=layer(x)
            skip_connections.append(x)
            x=self.pool(x)
        x=self.bottle_neck(x)

        for ind,layer in enumerate(self.decoder):
            x=self.up_samples[ind](x)
            y=copy_and_crop(skip_connections.pop(),x)
            x=layer(torch.cat([y,x],dim=1))
        x=self.final_1layer(x)

        return x


In [7]:
model=Unet(1,1)


In [14]:
pics=torch.rand((1,1,256,256))
output=model(pics)

In [15]:
writer=SummaryWriter('runs/U-net')
writer.add_graph(model,pics)

TypeError: type Tensor doesn't define __round__ method