# 学習済みモデルの利用

# 学習済み（オブジェクト）・モデルの保存、読み込み
torch.save(model, 'model_weight.pth')
model = torch.load('model_weight.pth')

# TorchScript形式でモデル保存/読み込み
model_scripted = torch.jit.script(model)
model_scripted.save('model_scripted.pth')

model = torch.jit.load('model_scripted.pth')

In [2]:
import pprint
import torch
import torchvision

# 画像分類
pprint.pprint([s for s in dir(torchvision.models) if s[0].isupper()], compact=True)

# セマンティックセグメンテーション
pprint.pprint([s for s in dir(torchvision.models.segmentation) if s[0].isupper()], compact=True)

# 物体検出
pprint.pprint([s for s in dir(torchvision.models.detection) if s[0].isupper()], compact=True)

# 動画分類
pprint.pprint([s for s in dir(torchvision.models.video) if s[0].isupper()], compact=True)

['AlexNet', 'AlexNet_Weights', 'ConvNeXt', 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Tiny_Weights',
 'DenseNet', 'DenseNet121_Weights', 'DenseNet161_Weights',
 'DenseNet169_Weights', 'DenseNet201_Weights', 'EfficientNet',
 'EfficientNet_B0_Weights', 'EfficientNet_B1_Weights',
 'EfficientNet_B2_Weights', 'EfficientNet_B3_Weights',
 'EfficientNet_B4_Weights', 'EfficientNet_B5_Weights',
 'EfficientNet_B6_Weights', 'EfficientNet_B7_Weights',
 'EfficientNet_V2_L_Weights', 'EfficientNet_V2_M_Weights',
 'EfficientNet_V2_S_Weights', 'GoogLeNet', 'GoogLeNetOutputs',
 'GoogLeNet_Weights', 'Inception3', 'InceptionOutputs', 'Inception_V3_Weights',
 'MNASNet', 'MNASNet0_5_Weights', 'MNASNet0_75_Weights', 'MNASNet1_0_Weights',
 'MNASNet1_3_Weights', 'MaxVit', 'MaxVit_T_Weights', 'MobileNetV2',
 'MobileNetV3', 'MobileNet_V2_Weights', 'MobileNet_V3_Large_Weights',
 'MobileNet_V3_Small_Weights', 'RegNet', 'RegNet_X_16GF_Weights',
 'RegNet_X_1_6GF_Weights', 

# 引数pretrainedをTrueとすると学習済みモデルが生成できる。

In [None]:
model = torchvision.models.resnet18(pretrained=True)
# モデルの保存


In [None]:
model2 = torchvision.models.vgg16(pretrained=True)

In [None]:
model_GAN = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN',
                           pretrained=True, useGPU=False)
torch.hub.list(github='facebookresearch/pytorch_GAN_zoo:hub')

In [None]:
torch.hub.help('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN')

In [None]:
pprint.pprint([s for s in dir(torchvision.models) if s[0].isupper()], compact=True)

In [None]:
from torch import utils
from torchvision import datasets
import torchvision.transforms as transforms

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

model_gan = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN',
                           pretrained=True, useGPU=False)

torch.save(model_gan, 'model_gan.pth')

In [1]:
# Download an example image from the pytorch website
import urllib.request

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

('dog.jpg', <http.client.HTTPMessage at 0x105e95f10>)

In [6]:
# sample execution (requires torchvision)
import torch
from PIL import Image
from torchvision import transforms

model = torch.load('../models/model_vgg16.pth')
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
input_batch = input_batch.to('mps')
model.to('mps')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

tensor([-3.8866e+00, -5.2854e+00, -2.7631e+00, -3.6242e+00, -2.5111e+00,
        -2.2481e+00, -2.7859e+00,  1.4291e+00,  5.3659e+00, -2.2100e+00,
        -5.1515e+00, -3.9944e+00, -4.3563e+00, -3.6430e+00, -3.4027e+00,
        -5.5418e+00, -2.3920e+00, -1.5563e+00, -1.3265e+00, -2.7161e+00,
        -4.1317e+00, -3.1217e+00, -5.5883e-01, -4.2849e-01, -3.2989e+00,
        -4.3049e+00, -4.3274e+00, -3.0898e+00, -2.2745e+00, -2.5528e+00,
        -6.7184e+00, -3.5171e+00, -5.1942e+00, -5.9203e+00, -4.9720e+00,
        -6.1499e+00, -4.0393e+00, -5.2016e+00, -6.6448e+00, -3.9008e+00,
        -2.8444e+00, -5.4609e+00, -6.9220e+00, -6.5001e+00, -5.1454e+00,
        -4.4600e+00, -9.4703e-01, -5.8946e+00, -6.2213e+00, -4.6719e+00,
        -2.7034e+00, -6.4176e+00, -2.9126e+00, -3.8957e+00, -4.7598e+00,
        -3.0818e+00, -4.0203e+00, -5.4063e+00, -4.5474e+00, -3.4727e+00,
        -1.2816e+00, -6.4201e+00, -7.4006e+00, -4.8451e+00, -4.4242e+00,
        -5.5235e+00, -6.1713e+00, -5.4518e+00, -7.0

In [7]:
# Download ImageNet labels
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

--2023-06-10 18:28:20--  https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8003::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10472 (10K) [text/plain]
Saving to: ‘imagenet_classes.txt’


2023-06-10 18:28:20 (67.9 MB/s) - ‘imagenet_classes.txt’ saved [10472/10472]



In [8]:
# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

Samoyed 0.9201600551605225
collie 0.016669986769557
Great Pyrenees 0.01547842938452959
Pomeranian 0.012156737968325615
Border collie 0.011078499257564545
