# 图像分类——ResNet

本范例使用 ResNet50 对猫和狗的图片进行二分类。

In [None]:
import numpy as np
import pandas as pd 
from matplotlib import pyplot as plt
from PIL import Image 

import torch
from torch import nn
import torchvision
from torchvision import datasets, models, transforms

import datetime
import os
import copy
import shutil
from pathlib import Path


In [None]:
from argparse import Namespace

# Hyper-parameters shared by the later cells:
#   img_size   -> images are resized to img_size x img_size pixels
#   lr         -> learning rate handed to the optimizer
#   batch_size -> batch size of both DataLoaders
config = Namespace(img_size=256, lr=1e-3, batch_size=32)


In [None]:
import sys 
# Put the local torchkeras checkout at the front of the import path so the
# import below resolves to it ahead of any other copy on sys.path.
# NOTE(review): the path is relative to the notebook's working directory and
# presumably points at a repo root containing a `torchkeras/` package — confirm.
sys.path.insert(0,'../../torchkeras')

# ResNet50 architecture used as the classifier backbone in section 2.
from torchkeras.models import ResNet50 


## 一，准备数据

In [None]:
data_url = 'https://github.com/lyhue1991/torchkeras/releases/download/v3.7.2/cats_vs_dogs.zip'
data_file = 'cats_vs_dogs.zip'

# Download the archive only if it is not already cached locally.
if not os.path.exists(data_file):
    torch.hub.download_url_to_file(data_url,data_file)

# Unpack independently of the download step: if the zip is present but the
# extracted folder is missing (interrupted or deleted extraction), re-extract
# instead of silently leaving './datasets/cats_vs_dogs' absent for later cells.
if not os.path.exists(os.path.join('datasets','cats_vs_dogs')):
    shutil.unpack_archive(data_file,'datasets')
    

In [None]:

data_path = './datasets/cats_vs_dogs'

# Preview one training cat image. os.listdir order is arbitrary, so sort it
# to show the same file on every run.
train_cats = sorted(os.listdir(os.path.join(data_path,"train","cats")))
img = Image.open(os.path.join(data_path,"train","cats",train_cats[0]))
img 

In [None]:
# Preview one training dog image; sort the listing so the choice is
# deterministic (os.listdir order is arbitrary).
train_dogs = sorted(os.listdir(os.path.join(data_path,"train","dogs")))
img = Image.open(os.path.join(data_path,"train","dogs",train_dogs[0]))
img 


In [None]:
# Training preprocessing: resize to a fixed square, random horizontal flip as
# light augmentation, then convert the PIL image to a CHW tensor.
# (The previous version had a missing comma after ToTensor() and a bogus
# `transforms.Normalizelize` entry, which is a SyntaxError.) No normalization
# is applied, matching the validation pipeline; this also lets the preview
# cell pass the tensors straight to imshow.
transforms_train = transforms.Compose([
    transforms.Resize((config.img_size,config.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation preprocessing: deterministic resize + tensor conversion only.
transforms_val = transforms.Compose([
    transforms.Resize((config.img_size,config.img_size)),
    transforms.ToTensor(),
])

def transform_label(x):
    """Turn a class index into a one-element float32 tensor.

    Used as ImageFolder's ``target_transform`` so each target is a float
    tensor of shape (1,), the form expected by the single-logit
    ``BCEWithLogitsLoss`` setup used later in this notebook.
    """
    as_list = [x]
    return torch.tensor(as_list, dtype=torch.float32)
    
    
# One sub-folder per class under train/ and val/; ImageFolder assigns the
# integer class index, which transform_label wraps as a float tensor.
ds_train = datasets.ImageFolder(
    os.path.join(data_path,"train"),
    transform=transforms_train,
    target_transform=transform_label,
)
ds_val = datasets.ImageFolder(
    os.path.join(data_path,"val"),
    transform=transforms_val,
    target_transform=transform_label,
)

# Shuffle only the training stream; validation order stays fixed.
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=config.batch_size, shuffle=True
)
dl_val = torch.utils.data.DataLoader(
    ds_val, batch_size=config.batch_size, shuffle=False
)

# Folder names in index order; used to title the preview images.
class_names = ds_train.classes

# Report dataset sizes (train first, then val).
for _split in (ds_train, ds_val):
    print(len(_split))


In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

plt.figure(figsize=(8,8)) 
# Pick 9 distinct training samples; replace=False prevents the same image
# from appearing twice in the 3x3 grid.
idxes = np.random.choice(len(ds_train),9,replace=False)

for i, idx in enumerate(idxes):
    # ds_train applies transforms_train, so previews may be randomly flipped.
    img,label = ds_train[idx]
    # ToTensor yields channel-first (C,H,W); imshow wants (H,W,C).
    img = img.permute(1,2,0)
    ax = plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    # label is a float tensor of shape (1,) from transform_label.
    ax.set_title(f"{class_names[label.int().item()]}")
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()



## 二，构建模型

In [None]:
# Take one training batch to use as example input for the model summary.
# next(iter(...)) raises StopIteration right away on an empty loader. The
# previous for/break silently left the names undefined, or stale from an
# earlier notebook run.
features,labels = next(iter(dl_train))

In [None]:
from torchkeras import summary 

# Single output unit: one logit per image for the binary cat/dog task.
net = ResNet50(num_classes = 1)

# Layer summary driven by a real batch; the trailing ';' stops the notebook
# from echoing summary()'s return value.
summary(net, input_data = features);


## 三，训练模型

In [None]:
from torchkeras import KerasModel 
from torchmetrics import Accuracy

# One logit per sample, so use the logits-aware binary cross-entropy.
loss_fn = nn.BCEWithLogitsLoss()
# The "acc" key is what the fit cell monitors as "val_acc".
metrics_dict = {"acc": Accuracy(task='binary')}

# Plain SGD with momentum; learning rate comes from the shared config.
optimizer = torch.optim.SGD(net.parameters(), lr=config.lr, momentum=0.9)

model = KerasModel(
    net,
    loss_fn=loss_fn,
    metrics_dict=metrics_dict,
    optimizer=optimizer,
)



In [None]:
# NOTE(review): WandbCallback is imported but not passed to fit() in this cell.
from torchkeras.kerascallbacks import WandbCallback

# Train and keep the per-epoch history (used by plot_metric below).
dfhistory = model.fit(
    train_data=dl_train,
    val_data=dl_val,
    epochs=6,
    ckpt_path='checkpoint.pt',
    # presumably early stopping: halt after 2 epochs without val_acc
    # improving (mode="max") — confirm against torchkeras docs
    patience=2,
    monitor="val_acc",
    mode="max",
    mixed_precision='no',
    plot=True,
)



## 四，评估模型

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt
def plot_metric(dfhistory, metric):
    """Draw per-epoch training and validation curves for one metric.

    Reads the columns ``"train_" + metric`` and ``"val_" + metric`` from
    ``dfhistory``. Epochs are numbered from 1 on the x-axis. The figure is
    shown immediately.
    """
    train_col = "train_" + metric
    val_col = 'val_' + metric
    train_series = dfhistory[train_col]
    val_series = dfhistory[val_col]
    xs = range(1, len(train_series) + 1)
    # Blue dashed circles for training, solid red for validation.
    for series, fmt in ((train_series, 'bo--'), (val_series, 'ro-')):
        plt.plot(xs, series, fmt)
    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend([train_col, val_col])
    plt.show()
    

In [None]:
# Per-epoch training vs. validation loss curves.
plot_metric(dfhistory,"loss")

In [None]:
# Per-epoch training vs. validation accuracy curves.
plot_metric(dfhistory,"acc")

In [None]:
# Evaluate the trained model on the validation loader; as the cell's last
# expression, its return value is displayed by the notebook.
model.evaluate(dl_val)

## 五，使用模型

## 六，保存模型


