In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [2]:
preprocess = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

In [4]:
datas = torchvision.datasets.ImageFolder('D://research/pytorch-implementations/data/animal10/',transform=preprocess)

In [5]:
loader = torch.utils.data.DataLoader(datas,batch_size=8,shuffle=True)

In [5]:
# model = torchvision.models.resnet50()

In [6]:
class ResBlock(nn.Module):
    
    def __init__(self, input_dim, block_dim):
        super(ResBlock, self).__init__()
        self.input_dim = input_dim
        self.output_dim = block_dim*4
        
        self.block = nn.Sequential(
            nn.Conv2d(input_dim,block_dim,1,bias=False),
            nn.Conv2d(block_dim,block_dim,3,padding=1,bias=False),
            nn.Conv2d(block_dim,self.output_dim,1,bias=False)
        )
        self.downsample = nn.Sequential(
            nn.Conv2d(input_dim, self.output_dim, 1,bias=False)
        )
        
    def forward(self, x):
        tmp = self.block(x)
        if self.input_dim != self.output_dim:
            res = self.downsample(x)
        else:
            res = x
        tmp = tmp + res
        return tmp

In [7]:
class ResNet50(nn.Module):
    
    def __init__(self):
        super(ResNet50, self).__init__()
        self.foot = nn.Sequential(
            nn.Conv2d(3,64,7,stride=2, padding = 5),
            nn.MaxPool2d(3,stride=2)
        )
        self.block1 = nn.Sequential(
            ResBlock(64,64),
            ResBlock(256,64),
            ResBlock(256,64)
        )
        self.block2 = nn.Sequential(
            ResBlock(256,128),
            ResBlock(512,128),
            ResBlock(512,128),
            ResBlock(512,128),
            nn.MaxPool2d(2,2)
        )
        self.block3 = nn.Sequential(
            ResBlock(512,256),
            ResBlock(1024,256),
            ResBlock(1024,256),
            ResBlock(1024,256),
            ResBlock(1024,256),
            ResBlock(1024,256),
            nn.MaxPool2d(2,2)
        )
        self.block4 = nn.Sequential(
            ResBlock(1024,512),
            ResBlock(2048,512),
            ResBlock(2048,512),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.fc = nn.Sequential(
            nn.Linear(2048,10)
        )

        
    def forward(self, x):
        tmp = self.foot(x)
        tmp = self.block1(tmp)
        tmp = self.block2(tmp)
        tmp = self.block3(tmp)
        tmp = self.block4(tmp)
        tmp = torch.flatten(tmp,start_dim=1)
        tmp = self.fc(tmp)

        return tmp

In [8]:
model = ResNet50().cuda()

In [9]:
model(torch.randn(8,3,224,224).cuda()).shape

torch.Size([8, 10])

In [10]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters())

In [11]:
epoches = 1

In [None]:
model.train()
training_loss = []
for e in range(epoches):
    running_loss = 0.0
    for data in tqdm(loader):
        img, label = data
        img = img.cuda()
        label = label.cuda()
        
        output = model(img)
        
        optim.zero_grad()
        loss = criterion(output,label)
        loss.backward()
        optim.step()
            
        running_loss += loss.item()
    training_loss.append(running_loss/len(datas))
    print(f'epoch {e+1}: {running_loss/len(datas)}')

In [20]:
model.eval()
with torch.no_grad():
    running_hit = 0.0
    for data in tqdm(loader):
        img, label = data
        img = img.cuda()
        
        output = model(img)
        output = output.detach().cpu()
        output = output.argmax(dim=1)
        acc = torch.sum(label == output)
        running_hit += acc.item()
    print(running_hit/len(datas))

100%|██████████████████████████████████████████████████████████████████████████████| 3273/3273 [04:59<00:00, 10.93it/s]

0.16413919553840864





In [21]:
torch.save(model.state_dict(),'resnet.pth')

In [12]:
model.load_state_dict(torch.load('resnet.pth'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [22]:
# Netron 檢查
torch.onnx.export(model.cpu(),torch.randn(1,3,224,224),'a.onnx')