[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mifumo081a/pytorch_template/blob/main/examples/notebooks/mnist.ipynb)


In [None]:
# !git clone https://github.com/mifumo081a/pytorch_template.git
# !ls

In [None]:
# !pip install -r pytorch_template/requirements.txt

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms, datasets
import os
import torchinfo

In [None]:
# Run from the author's local checkout so that data/, logs/ and models/
# (all built from os.getcwd() in the next cell) land inside it.
# Raw string: the previous plain literal contained invalid escape sequences
# ("\マ", "\p"), which Python 3.12+ reports as a SyntaxWarning.
# The directory is machine-specific, so only switch to it when it exists;
# elsewhere (e.g. Colab after `git clone`) the current directory is kept.
LOCAL_PROJECT_ROOT = r"G:\マイドライブ\pytorch_template"
if os.path.isdir(LOCAL_PROJECT_ROOT):
    os.chdir(LOCAL_PROJECT_ROOT)

In [None]:
# Base directory from which every data/log/model path in this notebook is built.
root = os.getcwd()
print(root)

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set to False to skip the cells guarded by `if TRAIN:`; the evaluation cells
# still load fold models from disk via get_model.
TRAIN = True

In [None]:
# Log directory for this example: <root>/logs/mnist/, created if missing.
logs_root = os.path.join(os.path.join(root, "logs/"), "mnist/")
os.makedirs(name=logs_root, exist_ok=True)

In [None]:
# MNIST training split (later divided into K train/val folds) and the held-out
# test split, both stored under <root>/data/ and downloaded if absent.
trainval_dataset, test_dataset = (
    datasets.MNIST(os.path.join(root, "data/"), train=is_train, download=True,
                   transform=transforms.ToTensor())
    for is_train in (True, False)
)

# Class names "0".."9"; presumably read by the template's validator for
# labelling plots (confirm in pytorch_template).
trainval_dataset.labels = list(map(str, range(10)))
test_dataset.labels = list(map(str, range(10)))

# Per-split preprocessing: currently only image -> tensor conversion.
# NOTE(review): trainval_dataset already carries its own ToTensor transform;
# confirm KFold_Dataset replaces it rather than stacking these on top.
train_transforms, val_transforms, test_transforms = (
    transforms.Compose([transforms.ToTensor()]) for _ in range(3)
)

test_dataset.transform = test_transforms

In [None]:
import pytorch_template.cross_validation as cv

In [None]:
# os.cpu_count() returns None when the core count cannot be determined, and
# DataLoader rejects a non-int num_workers; fall back to 0 (load in the main
# process) in that case.
num_workers = os.cpu_count() or 0
n_splits = 5
batch_size = 50

# Object holding the per-fold train/val datasets and their DataLoaders.
kfold = cv.KFold_Dataset(n_splits=n_splits, dataset=trainval_dataset)

# Build the fold splits with shuffling.
# NOTE(review): with random_state=None the folds presumably differ on every
# run; pass an int for reproducible splits (confirm in cross_validation).
kfold.get_datasets(train_transforms=train_transforms, val_transforms=val_transforms,
                   shuffle=True, random_state=None)
# Alternative: presumably reloads splits pickled by an earlier run.
# kfold.get_datasets(train_transforms=train_transforms, val_transforms=val_transforms,
#                    load_pickle=True)

kfold.get_dataloaders(batch_size=batch_size, num_workers=num_workers)

# Held-out MNIST test split, one shuffled image per batch.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

# Model initialization and training

In [None]:
from pytorch_template.models import GradCAM_Model

In [None]:
class Model(GradCAM_Model):
    """MNIST digit classifier built on a frozen slice of ImageNet-pretrained VGG16.

    The forward pass is split into ``get_features`` (convolutional trunk) and
    ``get_outputs`` (spatial pooling + linear head). Presumably
    ``GradCAM_Model`` relies on this split to compute Grad-CAM maps on the
    trunk's output; confirm in ``pytorch_template.models``.
    """

    def __init__(self):
        super().__init__()
        # Load the same ImageNet weights that ``pretrained=True`` selects.
        # torchvision >= 0.13 uses the weights enum (``pretrained`` is
        # deprecated there); older releases lack the enum and use the flag.
        try:
            backbone = torchvision.models.vgg16(
                weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
            )
        except AttributeError:
            backbone = torchvision.models.vgg16(pretrained=True)

        # Keep only the first 10 layers of VGG16's feature extractor. Their
        # output has 128 channels, matching the Linear(128, 10) head below.
        # The trunk is frozen, so only the head receives gradient updates.
        self.features = backbone.features[:10]
        self.features.requires_grad_(False)

        # NOTE(review): a Sigmoid on class scores is unusual if
        # Trainer_Classifier uses nn.CrossEntropyLoss (which expects raw
        # logits); confirm the trainer's loss before removing it.
        self.classifier = nn.Sequential(
            nn.Linear(128, 10),
            nn.Sigmoid(),
        )

    def get_features(self, x):
        """Run the frozen VGG trunk on ``x``.

        Assumes single-channel input of shape (N, 1, H, W), e.g. the
        (1, 1, 28, 28) batch used in the summary call. The channel is tiled
        to 3 because VGG16's first convolution expects 3 input channels.
        """
        x = x.repeat(1, 3, 1, 1)
        return self.features(x)

    def get_outputs(self, x):
        """Average feature maps over the two spatial dims, then score 10 classes."""
        x = x.mean([2, 3])
        return self.classifier(x)

    def forward(self, x):
        """Return the 10 per-class scores for a batch of images."""
        x = self.get_features(x)
        x = self.get_outputs(x)
        return x

# Layer-by-layer summary for a single 1x28x28 MNIST image.
torchinfo.summary(Model(), input_size=(1, 1, 28, 28))

In [None]:
from pytorch_template.utils import get_model, set_model
from pytorch_template.trainer import Trainer_Classifier

In [None]:
# Fold checkpoints are read back from <model_root>/<k>/ in the evaluation cells.
model_root = os.path.join(root, "models/")
print(f"{model_root}")

model = Model()

In [None]:
if TRAIN:
    trainer_list = []
    for k in range(n_splits):
        # Build a fresh model for every fold. Reusing one instance would start
        # fold k from weights already trained on the other folds, whose
        # training data include fold k's validation set. That leaks
        # validation data and inflates the cross-validated accuracy.
        fold_model = Model()
        trainer = Trainer_Classifier(
            model=fold_model,
            device=device,
            dataloaders=kfold.dataloaders[k],
            optimizer=torch.optim.Adam(fold_model.parameters(), lr=1e-3),
            epochs=2,
        )
        trainer_list.append(trainer)

In [None]:
if TRAIN:
    # Train all fold trainers; presumably each fold's model is saved under
    # model_root (the evaluation cells reload from model_root/<k>/).
    cv.kfold_train(
        save_model_root=model_root,
        n_splits=n_splits,
        trainer_list=trainer_list,
    )

In [None]:
if TRAIN:
    # Display each fold's training curve; save=False presumably skips
    # writing the figure under <logs_root>/curves/.
    curves_dir = os.path.join(logs_root, "curves/")
    for k in range(n_splits):
        fold_trainer = trainer_list[k]
        fold_trainer.show_curve(logs_root=curves_dir, fname=str(k), save=False)

# Evaluate models

In [None]:
from pytorch_template.validator import ImageClassifier_Validator

In [None]:
# One validator per fold: the "model_fit" checkpoint from <model_root>/<k>/,
# evaluated on that fold's validation loader.
evaluator_list = [
    ImageClassifier_Validator(
        model=get_model("model_fit", os.path.join(model_root, str(k) + "/"), device),
        device=device,
        dataloaders=kfold.dataloaders[k]["val"],
        logs_root=logs_root,
    )
    for k in range(n_splits)
]

In [None]:
# Draw each fold's confusion matrix (save=False) and collect its accuracy.
# confusion_matrix() is called first; presumably it is what sets `.acc`.
kfold_acc = []
for k in range(n_splits):
    fold_eval = evaluator_list[k]
    fold_eval.confusion_matrix(folder_name="confusion_matrix", fname=str(k), save=False)
    kfold_acc.append(fold_eval.acc)

In [None]:
# Mean and population standard deviation (ddof=0) of the per-fold accuracies.
print(f"Acc mean: {np.mean(kfold_acc):.4f}, std: {np.std(kfold_acc):.4f}")

In [None]:
# Show each fold validator's score report (save=False).
for fold in range(n_splits):
    evaluator_list[fold].show_scores(folder_name="eval", fname=str(fold), save=False)

In [None]:
# Show Grad-CAM visualisations for each fold's model (save=False).
for fold_idx in range(n_splits):
    evaluator_list[fold_idx].show_cam(folder_name="cam", fname=str(fold_idx), save=False)