In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [2]:
preprocess = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

In [3]:
datas = torchvision.datasets.ImageFolder('animal10/',transform=preprocess)

In [4]:
loader = torch.utils.data.DataLoader(datas,batch_size=8,shuffle=True)

In [6]:
class ConvBlock(nn.Module):
    
    def __init__(self, input_dim, first_dim, second_dim, output_dim):
        super(ConvBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.block = nn.Sequential(
            nn.Conv2d(input_dim,first_dim,1),
            nn.Conv2d(first_dim,second_dim,3,padding=1),
            nn.Conv2d(second_dim,output_dim,1)
        )
        self.downsample = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, 1)
        )
        
    def forward(self, x):
        tmp = self.block(x)
        if self.input_dim != self.output_dim:
            res = self.downsample(x)
        else:
            res = x
        tmp = tmp + res
        return tmp

In [7]:
class ResNet(nn.Module):
    
    def __init__(self):
        super(ResNet, self).__init__()
        self.foot = nn.Sequential(
            nn.Conv2d(3,64,7,stride=2, padding = 5),
            nn.MaxPool2d(3,stride=2)
        )
        self.block1 = nn.Sequential(
            ConvBlock(64,64,64,256),
            ConvBlock(256,64,64,256),
            ConvBlock(256,64,64,256)
        )
        self.block2 = nn.Sequential(
            ConvBlock(256,128,128,512),
            ConvBlock(512,128,128,512),
            ConvBlock(512,128,128,512),
            ConvBlock(512,128,128,512),
            nn.MaxPool2d(2,2)
        )
        self.block3 = nn.Sequential(
            ConvBlock(512,256,256,1024),
            ConvBlock(1024,256,256,1024),
            ConvBlock(1024,256,256,1024),
            ConvBlock(1024,256,256,1024),
            ConvBlock(1024,256,256,1024),
            ConvBlock(1024,256,256,1024),
            nn.MaxPool2d(2,2)
        )
        self.block4 = nn.Sequential(
            ConvBlock(1024,512,512,2048),
            ConvBlock(2048,512,512,2048),
            ConvBlock(2048,512,512,2048),
            nn.AvgPool2d(2,2)
        )
        self.fc = nn.Sequential(
            nn.Linear(2048*7*7,10)
        )

        
    def forward(self, x):
        tmp = self.foot(x)
        tmp = self.block1(tmp)
        tmp = self.block2(tmp)
        tmp = self.block3(tmp)
        tmp = self.block4(tmp)
        tmp = torch.flatten(tmp,start_dim=1)
        tmp = self.fc(tmp)

        return tmp

In [14]:
model = ResNet().cuda()

In [9]:
model(torch.randn(8,3,224,224).cuda()).shape

torch.Size([8, 10])

In [10]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())

In [11]:
epoches = 1

In [15]:
model.train()
training_loss = []
for e in range(epoches):
    running_loss = 0.0
    for data in tqdm(loader):
        img, label = data
        img = img.cuda()
        label = label.cuda()
        
        output = model(img)
        
        optim.zero_grad()
        loss = criterion(output,label)
        loss.backward()
        optim.step()
            
        running_loss += loss.item()
    training_loss.append(running_loss/len(data))
    print(f'epoch {e+1}: {running_loss/len(data)}')

100%|██████████████████████████████████████████████████████████████████████████████| 3273/3273 [23:40<00:00,  2.30it/s]

epoch 1: 3765.3061668872833





In [18]:
model.eval()
with torch.no_grad():
    running_hit = 0.0
    for data in tqdm(loader):
        img, label = data
        img = img.cuda()
        
        output = model(img)
        output = output.detach().cpu()
        output = output.argmax(dim=1)
        acc = torch.sum(label == output)
        running_hit += acc.item()
    print(running_hit/len(datas))

100%|██████████████████████████████████████████████████████████████████████████████| 3273/3273 [05:21<00:00, 10.18it/s]

0.07146949845295848





In [19]:
torch.save(model.state_dict(),'resnet.pth')

In [None]:
model.load_state_dict('resnet.pth')

In [12]:
# Netron 檢查
torch.onnx.export(model.cpu(),torch.randn(1,3,224,224),'a.onnx')