<a href="https://colab.research.google.com/github/daisuke08253649/DeepLearning/blob/main/illustrationidentification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install timm; it supplies list_models() and create_model() used in later cells.
!pip install timm

In [None]:
import timm
# Ask timm for the model names matching pretrained=True and print them,
# so that one can be picked for timm.create_model().
model_list = timm.list_models(pretrained=True)
print(model_list)

In [3]:
import os
import torch
import torchvision
import pandas as pd

from torch import nn
from PIL import Image
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

In [None]:
# Root folders of the training and validation images (paths are relative to
# the notebook's working directory).
train_image_dir = './drive/MyDrive/DeepLearning/illustrationdiscrimination/data/train'
val_image_dir = './drive/MyDrive/DeepLearning/illustrationdiscrimination/data/val'

# Both splits share the same preprocessing: resize to 224x224, then convert
# the PIL image to a tensor.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
}

train_dataset = torchvision.datasets.ImageFolder(root=train_image_dir, transform=data_transforms['train'])
val_dataset = torchvision.datasets.ImageFolder(root=val_image_dir, transform=data_transforms['val'])

train_dataloader = DataLoader(train_dataset, batch_size=80, shuffle=True)
# One validation loader under two names: `val_dataloader` matches the
# parameter name of val(), and `test_dataloader` is kept for cells that
# already refer to the loader by that name.
val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=True)
test_dataloader = val_dataloader

# Number of samples in each split.
print(len(train_dataset))
print(len(val_dataset))

In [None]:
def train(model, train_dataloader, criterion, optimizer, train_losses):
    """Run one training epoch and record its mean batch loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` run per batch.
        train_losses: List to which this epoch's mean loss is appended.

    Returns:
        The mean of the per-batch loss values for this epoch.
    """
    model.train()

    # Follow the model's own device rather than a notebook-level `device`
    # global, so the function does not depend on cell execution order.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = 0.0
    for inputs, labels in train_dataloader:
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average over the number of batches, not the number of samples.
    train_loss = running_loss / len(train_dataloader)
    train_losses.append(train_loss)

    return train_loss

def val(model, val_dataloader, criterion, val_losses, val_accuracies):
    """Evaluate the model on a validation loader and record loss and accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        val_losses: List to which the mean batch loss is appended.
        val_accuracies: List to which the accuracy is appended.

    Returns:
        A ``(val_loss, val_accuracy)`` tuple: the mean of the per-batch losses
        and the fraction of evaluated samples whose arg-max prediction
        equals the label.
    """
    model.eval()

    # Follow the model's own device rather than a notebook-level `device`
    # global, so the function does not depend on cell execution order.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = 0.0
    correct = 0
    seen = 0

    # No gradients are needed anywhere in evaluation.
    with torch.no_grad():
        for inputs, labels in val_dataloader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    val_loss = running_loss / len(val_dataloader)
    # Divide by the samples actually evaluated; this equals the dataset size
    # for a full pass and stays correct if the loader skips samples.
    val_accuracy = correct / seen

    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    return val_loss, val_accuracy


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Pretrained timm checkpoint, moved to the selected device.
# NOTE(review): the classifier head is left at the checkpoint's default size;
# presumably it should match the number of classes in the image folders
# (timm.create_model accepts num_classes) - confirm before relying on accuracy.
model = timm.create_model('vit_base_patch16_clip_224.openai_ft_in12k_in1k', pretrained=True)
model = model.to(device)

# Training hyper-parameters.
lr = 1e-3
epochs = 50
criterion = nn.CrossEntropyLoss()
# The import cell does not bring a bare `optim` name into scope, so reach
# Adam through the already imported `torch` package.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Per-epoch histories filled in by train() and val().
train_losses = []
val_losses = []
val_accuracies = []

In [None]:
# Train for `epochs` epochs, evaluating after each one.
# `test_dataloader` is the loader built over the validation dataset.
for i in range(epochs):
    train_loss = train(model, train_dataloader, criterion, optimizer, train_losses)
    val_loss, val_accuracy = val(model, test_dataloader, criterion, val_losses, val_accuracies)

    print(f'Epoch: {i+1}, Train_loss: {train_loss}, Val_loss: {val_loss}, Val_accuracy: {val_accuracy}')

In [None]:
# Plot the per-epoch loss histories (`train_losses` / `val_losses`) that
# train() and val() append to; the x axis is the 1-based epoch number.
plt.figure()
plt.plot(range(1, len(train_losses) + 1), train_losses, label='train_loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='test_loss')
plt.xlabel('epoch')
plt.legend()
plt.show()

In [None]:
from torch.autograd import Variable

# Pull a single batch from the validation loader for a qualitative check.
test_iter = iter(test_dataloader)
true_list = []
false_list = []
images, labels = next(test_iter)

# Run the batch on whichever device the model lives on (no hard-coded CUDA),
# in eval mode and without building an autograd graph.
model_device = next(model.parameters()).device
model.eval()
with torch.no_grad():
    output = model(images.to(model_device))
# Bring predictions back to the CPU so they sit next to the CPU labels.
predict = output.argmax(dim=1).cpu()

# Split the batch into correctly and incorrectly classified samples; each
# entry is [image tensor, true label, predicted label].
for image, label, prediction in zip(images, labels, predict):
    entry = [image, label, prediction]
    if int(label) == int(prediction):
        true_list.append(entry)
    else:
        false_list.append(entry)

# Show up to five correctly classified images.
print(f'予測が正解しているデータ:{len(true_list)}')
for idx, tlst in enumerate(true_list[:5]):
    plt.figure(idx + 1)
    image = tlst[0].cpu().numpy()
    # Channel-first (3, H, W) -> channel-last (H, W, 3) as imshow expects.
    if image.ndim == 3 and image.shape[0] == 3:
        image = image.transpose((1, 2, 0))
    plt.imshow(image, cmap='Blues')
    # int() keeps tensor reprs out of the title.
    plt.title('True: {}, Estim: {}'.format(int(tlst[1]), int(tlst[2])))

In [None]:
# Show up to five misclassified images from `false_list`; each entry is
# [image tensor, true label, predicted label].
print(f'予測が不正解のデータ:{len(false_list)}')
for idx, flst in enumerate(false_list[:5]):
    plt.figure(idx + 1)
    image = flst[0].cpu().numpy()
    # Channel-first (3, H, W) -> channel-last (H, W, 3) as imshow expects.
    if image.ndim == 3 and image.shape[0] == 3:
        image = image.transpose((1, 2, 0))
    plt.imshow(image, cmap='Reds')
    # int() keeps tensor reprs (e.g. "tensor(1, device='cuda:0')") out of the title.
    plt.title('True: {}, Estim: {}'.format(int(flst[1]), int(flst[2])))

In [None]:
# Save the model
# The model object itself (not just a state_dict) is handed to torch.save and
# written to a file in the current working directory.
torch.save(model, 'illustrationidentification_model.pth')

In [None]:
# Load the model
# NOTE(review): the path literal looks like a placeholder ("path of the saved
# model"); point it at the real checkpoint file before running.
# The save cell pickles the whole model object, so the load must allow full
# unpickling: PyTorch >= 2.6 defaults to weights_only=True, which rejects
# such files. Unpickling can execute arbitrary code - only load trusted files.
model = torch.load('./保存されたモデルのパス', weights_only=False)
# Switch to evaluation mode for inference.
model.eval()