<a href="https://colab.research.google.com/github/wzjcaf/colab/blob/main/fangzi6_14.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import torch
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
#from torchinfo import summary

In [73]:
epochs = 10

In [74]:
data_transform = transforms.Compose([transforms.Resize([50,50]),transforms.ToTensor()])

In [75]:
class Mydataset(Dataset):
  def __init__(self,root,img_data,label_data,transform):
    self.root_dir = root
    self.img_dir = img_data
    self.label_dir = label_data
    self.img_path = os.path.join(self.root_dir,self.img_dir)
    self.label_path = os.path.join(self.root_dir,self.label_dir)
    self.img_list = os.listdir(self.img_path)
    self.label_list = os.listdir(self.label_path)
    self.transforms = transform
    self.img_list.sort()
    self.label_list.sort()
  def __getitem__(self,idx):
    img_name = self.img_list[idx]
    label_name = self.label_list[idx]
    img_idx_path = os.path.join(self.root_dir,self.img_dir,img_name)
    label_idx_path = os.path.join(self.root_dir,self.label_dir,label_name)
    #图像
    img = Image.open(img_idx_path)
    img = self.transforms(img)
    #标签
    label = Image.open(label_idx_path)
    label = self.transforms(label)
    label = label.squeeze(0)
    label = label.to(torch.int64)
    return img,label
  def __len__(self):
    #assert len(self.img_list) == len(self.label_list)
    return len(self.img_list)

In [76]:
train_root = "/content/drive/MyDrive/data/train"
test_root = "/content/drive/MyDrive/data/test"
lab = "lab"
img = "img"

In [77]:
traindata = Mydataset(train_root,img,lab,data_transform)
testdata = Mydataset(test_root,img,lab,data_transform)

In [78]:
batch_size=4
train_dataset = DataLoader(traindata,batch_size,shuffle=True)
test_dataset = DataLoader(testdata,batch_size,shuffle=True)

In [79]:
class DoubleConv(nn.Module):
    

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
model = UNet(3,2)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
#summary(model,(batch_size,3,32,32))

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [80]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.8, momentum=0.5)

In [81]:
def train(epoch):
    model.train()
    train_loss = 0
    for img, label in train_dataset:
      img = img.to(device)
      #label = F.one_hot(label,num_classes=2)
      #label (batch_size,256,256)
      #label = torch.permute(label, (0,3,1,2))
      #label = label.float()
      label = label.to(device)
      #print(label.shape)
      # Forward pass
      output = model(img)
      #print(output.shape)
      loss = criterion(output, label)
        
      # Backward and optimize
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      train_loss += loss.item()*img.size(0)
    train_loss = train_loss/len(train_dataset.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

In [82]:
def val(epoch):
  model.eval()
  running_loss = 0.0
  running_correct = 0.0
  with torch.no_grad():
        for img,label in test_dataset:
          img = img.to(device)
          label = label.to(device)
          output = model(img)
          #print(label.shape,output.shape)
          loss = criterion(output,label)
          running_loss += loss.item()*img.size(0)
          running_correct += torch.sum(torch.argmax(output.data,1)==label.data)/(256*256)
        epoch_loss = running_loss/len(test_dataset.dataset)
        epoch_correct = running_correct/len(test_dataset.dataset)
        print('Epoch: {} \tval loss:{:.6f}  acc:{:.6f}'.format(epoch,epoch_loss,epoch_correct))

In [None]:
for epoch in range(1, epochs+1):
  train(epoch)
  val(epoch)

Epoch: 1 	Training Loss: 0.061666
Epoch: 1 	val loss:0.002029  acc:0.038147
Epoch: 2 	Training Loss: 0.000001
Epoch: 2 	val loss:0.000055  acc:0.038147
Epoch: 3 	Training Loss: 0.000001
Epoch: 3 	val loss:0.000004  acc:0.038147
Epoch: 4 	Training Loss: 0.000001
Epoch: 4 	val loss:0.000002  acc:0.038147
