In [1]:
%matplotlib inline
import torch as t
import os
import numpy as np
from PIL import Image
from torchvision import transforms as T
from torch.utils.data import Dataset
import pandas as pd
import re
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F

In [2]:
# Preprocessing pipeline applied to every PIL image before it reaches the model.
transform = T.Compose([
    # Scale the shorter side to 224 px, then cut out the central 224x224 window.
    T.Resize(224),
    T.CenterCrop(224),
    # PIL image -> float tensor.
    T.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 for the three colour channels.
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

class DataClassify(Dataset):
    """Dataset of ``.jpg`` images and ``.txt`` label files stored in one directory.

    Images and label files are paired by position: the i-th image path is
    matched with the i-th label path. Both lists are sorted so that this
    pairing is deterministic (``os.scandir`` yields entries in arbitrary
    order).
    NOTE(review): this assumes every image has a label file with the same
    stem -- confirm against the data layout.
    """

    def __init__(self, root, transforms=None):
        """Collect image and label paths found directly under ``root``.

        Args:
            root: Directory to scan (non-recursive).
            transforms: Optional callable applied to each loaded PIL image.
        """
        # Scan the directory once and release the OS handle right away.
        with os.scandir(root) as entries:
            found = [(entry.name, entry.path) for entry in entries]
        # Paths of the image files.
        self.imgs = sorted(path for name, path in found
                           if name.endswith(".jpg"))
        # Paths of the label files.
        self.labels = sorted(path for name, path in found
                             if name.endswith(".txt"))
        self.transforms = transforms

    def _read_label(self, index):
        """Return the digits found in the last four characters of label file ``index``."""
        # Context manager so the file handle is closed deterministically.
        with open(self.labels[index]) as label_file:
            tail = label_file.read()[-4:]
        # Raw string: '\D' in a normal literal is an invalid escape sequence.
        return re.sub(r'\D', '', tail)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(data, label)`` where ``data`` is the image (after ``transforms``
            if one was given) and ``label`` is a string of digits.
        """
        img_path = self.imgs[index]
        label = self._read_label(index)
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.imgs)
    
#dataset = DataClassify('D:/garbage/try', transforms=transform)
#img, label = dataset[0]
#for img, label in dataset:
#   print(img.size(), label)

In [3]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a skip connection.

    Args:
        inchannel: Number of channels of the input tensor.
        outchannel: Number of channels produced by the block.
        stride: Stride of the first convolution (the second always uses 1).
        shortcut: Optional module applied to the input before it is added to
            the main branch. When ``None`` the input itself is added, so it
            must already have the same shape as the main-branch output.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResBlock, self).__init__()
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        # Skip branch. This was previously stored under the misspelled name
        # "rigth", so forward() failed with AttributeError on self.right.
        self.right = shortcut

    def forward(self, x):
        """Return ``relu(left(x) + skip(x))``."""
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

In [4]:
class ResNet(nn.Module):
    """Residual network: a conv stem, four stages of ``ResBlock`` and a linear head.

    The four stages contain 3, 4, 6 and 3 residual blocks and produce
    128, 256, 512 and 512 channels respectively.

    Args:
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )
        self.layer1 = self._make_layer(64, 128, 3)
        self.layer2 = self._make_layer(128, 256, 4, stride=2)
        self.layer3 = self._make_layer(256, 512, 6, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)

        # Classification head.
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """Build one stage made of ``block_num`` residual blocks.

        The first block maps ``inchannel`` -> ``outchannel`` with the given
        ``stride`` and uses a 1x1 conv + BN projection as its skip branch;
        the remaining blocks keep ``outchannel`` channels and use identity
        skips.
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        # The first block must accept `inchannel` channels so that it matches
        # both the incoming tensor and the projection shortcut.
        layers = [ResBlock(inchannel, outchannel, stride, shortcut)]
        layers.extend(ResBlock(outchannel, outchannel)
                      for _ in range(1, block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch of 3-channel images and return class scores."""
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pooling. For 224x224 inputs the feature map is 7x7,
        # so this matches a fixed 7x7 average pool, but it also yields a
        # 1x1 map for other input resolutions.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return self.fc(x)