<a href="https://colab.research.google.com/github/wzjcaf/colab/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [60]:
import torch
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
from torchinfo import summary

In [61]:
data_transform = transforms.Compose([transforms.Resize([32,32]),transforms.ToTensor()])

In [62]:
class Mydataset(Dataset):
  def __init__(self,root,img_data,label_data,transform):
    self.root_dir = root
    self.img_dir = img_data
    self.label_dir = label_data
    self.img_path = os.path.join(self.root_dir,self.img_dir)
    self.label_path = os.path.join(self.root_dir,self.label_dir)
    self.img_list = os.listdir(self.img_path)
    self.label_list = os.listdir(self.label_path)
    self.transforms = transform
    self.img_list.sort()
    self.label_list.sort()
  def __getitem__(self,idx):
    img_name = self.img_list[idx]
    label_name = self.label_list[idx]
    img_idx_path = os.path.join(self.root_dir,self.img_dir,img_name)
    label_idx_path = os.path.join(self.root_dir,self.label_dir,label_name)
    #图像
    img = Image.open(img_idx_path)
    img = self.transforms(img)
    #标签
    label = Image.open(label_idx_path)
    label = self.transforms(label)
    return img,label
  def __len__(self):
    #assert len(self.img_list) == len(self.label_list)
    return len(self.img_list)

In [63]:
train_root = "/content/drive/MyDrive/data/train"
test_root = "/content/drive/MyDrive/data/test"
lab = "lab"
img = "img"

In [64]:
traindata = Mydataset(train_root,img,lab,data_transform)
testdata = Mydataset(test_root,img,lab,data_transform)

In [65]:
batch_size=4
train_dataset = DataLoader(traindata,batch_size,shuffle=True)
test_dataset = DataLoader(testdata,batch_size,shuffle=True)

In [66]:
class DoubleConv(nn.Module):
    

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
model = UNet(3,1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
summary(model,(batch_size,3,32,32))

Layer (type:depth-idx)                        Output Shape              Param #
UNet                                          [4, 1, 32, 32]            --
├─DoubleConv: 1-1                             [4, 64, 32, 32]           --
│    └─Sequential: 2-1                        [4, 64, 32, 32]           --
│    │    └─Conv2d: 3-1                       [4, 64, 32, 32]           1,728
│    │    └─BatchNorm2d: 3-2                  [4, 64, 32, 32]           128
│    │    └─ReLU: 3-3                         [4, 64, 32, 32]           --
│    │    └─Conv2d: 3-4                       [4, 64, 32, 32]           36,864
│    │    └─BatchNorm2d: 3-5                  [4, 64, 32, 32]           128
│    │    └─ReLU: 3-6                         [4, 64, 32, 32]           --
├─Down: 1-2                                   [4, 128, 16, 16]          --
│    └─Sequential: 2-2                        [4, 128, 16, 16]          --
│    │    └─MaxPool2d: 3-7                    [4, 64, 16, 16]           --
│    │    └

In [67]:
class DiceLoss(nn.Module):
    def __init__(self,weight=None,size_average=True):
        super(DiceLoss,self).__init__()
        
    def forward(self,inputs,targets,smooth=1):
        inputs = torch.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                   
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        return 1 - dice

In [68]:
criterion = DiceLoss()
optimizer = optim.SGD(model.parameters(), lr=0.8, momentum=0.5)

In [69]:
total_step = len(train_dataset)
num_epochs = 2
for epoch in range(num_epochs):
    for i, (img, label) in enumerate(train_dataset):
        img = img.to(device)
        label = label.to(device)

        # Forward pass
        outputs = model(img)
        
        loss = criterion(outputs, label)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 1 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch [1/2], Step [1/12], Loss: 0.9924
Epoch [1/2], Step [2/12], Loss: 0.9904
Epoch [1/2], Step [3/12], Loss: 0.9923
Epoch [1/2], Step [4/12], Loss: 0.9860
Epoch [1/2], Step [5/12], Loss: 0.9882
Epoch [1/2], Step [6/12], Loss: 0.9842
Epoch [1/2], Step [7/12], Loss: 0.9921
Epoch [1/2], Step [8/12], Loss: 0.9861
Epoch [1/2], Step [9/12], Loss: 0.9850
Epoch [1/2], Step [10/12], Loss: 0.9859
Epoch [1/2], Step [11/12], Loss: 0.9896
Epoch [1/2], Step [12/12], Loss: 0.9899
Epoch [2/2], Step [1/12], Loss: 0.9869
Epoch [2/2], Step [2/12], Loss: 0.9912
Epoch [2/2], Step [3/12], Loss: 0.9866
Epoch [2/2], Step [4/12], Loss: 0.9881
Epoch [2/2], Step [5/12], Loss: 0.9845
Epoch [2/2], Step [6/12], Loss: 0.9863
Epoch [2/2], Step [7/12], Loss: 0.9812
Epoch [2/2], Step [8/12], Loss: 0.9847
Epoch [2/2], Step [9/12], Loss: 0.9886
Epoch [2/2], Step [10/12], Loss: 0.9847
Epoch [2/2], Step [11/12], Loss: 0.9837
Epoch [2/2], Step [12/12], Loss: 0.9835


In [70]:
# Test
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for img, label in test_dataset:
        img = img.to(device)
        label = label.to(device)
        outputs = model(img)
        _, predicted = torch.max(outputs.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

Test Accuracy of the model on the 10000 test images: 289152.17391304346 %
