In [1]:
import numpy as np
import torch
import torchvision
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F


In [2]:
class TransformDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform):
        super().__init__()
        self.dataset = dataset
        self.transform = transform
        
        
    def __getitem__(self, n):
        data = self.dataset[n]
        return self.transform(data[0]), data[1]
    
    def __len__(self):
        return len(self.dataset)
        

In [3]:
folder = r'C:\Users\chtti\Downloads\拍照簽收圖檔-20211027T020311Z-001\拍照簽收圖檔'
BATCH_SIZE = 32

train_transform = T.Compose([
    T.AutoAugment(T.autoaugment.AutoAugmentPolicy.IMAGENET),
    T.Resize((540, 540)),
    T.ToTensor()
])
val_transform = T.Compose([
    T.Resize((540, 540)),
    T.ToTensor()
])

dataset = ImageFolder(folder)

train_length = int(len(dataset) * 0.8)
val_length = len(dataset) - train_length
train_set, val_set = random_split(dataset, [train_length, val_length], torch.Generator().manual_seed(42))
train_set_aug = TransformDataset(train_set, train_transform)
val_set_trans = TransformDataset(val_set, val_transform)

train_loader = DataLoader(train_set_aug, batch_size=BATCH_SIZE, shuffle=True,
                          pin_memory=True)
val_loader = DataLoader(val_set_trans, batch_size=8)

In [4]:
len(dataset)

1006

In [5]:
dataset.classes

['bad', 'good']

In [6]:
train_set[0]

(<PIL.Image.Image image mode=RGB size=1080x1920 at 0x146AAAFA4F0>, 0)

In [7]:
# next(iter(train_loader))

In [8]:
plt.rcParams["savefig.bbox"] = 'tight'
plt.rcParams['figure.figsize'] = [24, 8]


def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])



In [9]:
# sample = next(iter(train_loader))
# grid = make_grid(sample[0])
# show(grid)
# print(sample[1])

In [10]:
# sample = next(iter(val_loader))
# print(sample[0].shape)
# grid = make_grid(sample[0])
# show(grid)
# print(sample[1])

In [11]:
resnet = torchvision.models.resnet18(pretrained=True)

In [12]:
net = torchvision.models.resnet18(num_classes=2)

In [13]:
d = resnet.state_dict()
del d['fc.weight']
del d['fc.bias']

net.load_state_dict(d, strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = net.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(net.parameters(), lr=1e-5)

In [15]:
MAX_EPOCH = 20000

for epoch in range(MAX_EPOCH):
    net.train()
    for i, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)

        pred = net(data)
        loss = loss_fn(pred, label)

        optim.zero_grad()
        loss.backward()
        optim.step()
        
    if epoch % 1 == 0:
        net.eval()
        with torch.no_grad():
            val_loss, correct = 0, 0
            for i, (data, label) in enumerate(val_loader):
                data = data.to(device)
                label = label.to(device)
                
                pred = net(data)
                val_loss += loss_fn(pred, label).item()
                correct += (pred.argmax(1) == label).type(torch.float).sum().item()
                
        
        val_loss /= len(val_loader)
        correct /= len(val_loader.dataset)
        print('Batch: {}, Loss: {:>5f}, TestLoss: {:>5f}, Acc: {:>2f}'.format(
            epoch, loss.item(), val_loss, correct))

Batch: 0, Loss: 0.494035, TestLoss: 0.666929, Acc: 0.549505
Batch: 1, Loss: 0.767981, TestLoss: 0.677573, Acc: 0.529703
Batch: 2, Loss: 0.588207, TestLoss: 0.677411, Acc: 0.529703
Batch: 3, Loss: 0.578859, TestLoss: 0.680105, Acc: 0.529703
Batch: 4, Loss: 0.614245, TestLoss: 0.670171, Acc: 0.539604
Batch: 5, Loss: 0.688068, TestLoss: 0.675669, Acc: 0.529703
Batch: 6, Loss: 0.781302, TestLoss: 0.669249, Acc: 0.539604
Batch: 7, Loss: 0.504068, TestLoss: 0.665330, Acc: 0.579208
Batch: 8, Loss: 0.754653, TestLoss: 0.665687, Acc: 0.549505
Batch: 9, Loss: 0.810816, TestLoss: 0.664055, Acc: 0.544554
Batch: 10, Loss: 0.776719, TestLoss: 0.657838, Acc: 0.584158
Batch: 11, Loss: 0.842552, TestLoss: 0.661281, Acc: 0.554455
Batch: 12, Loss: 0.712167, TestLoss: 0.658559, Acc: 0.589109
Batch: 13, Loss: 0.616448, TestLoss: 0.655595, Acc: 0.599010
Batch: 14, Loss: 0.774644, TestLoss: 0.656756, Acc: 0.599010
Batch: 15, Loss: 0.814387, TestLoss: 0.653264, Acc: 0.618812
Batch: 16, Loss: 0.630168, TestLos

KeyboardInterrupt: 

In [16]:
net.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [17]:
torch.save(net, 'model_211201.pt')

In [3]:
cls_name = ['bad', 'good']

In [2]:
cls_name = dataset.classes

NameError: name 'dataset' is not defined

In [4]:
import torch


In [13]:
net = torch.load('model_211201.pt')

In [42]:
import mlflow

In [49]:
import image_pyfunc

from importlib import reload
image_pyfunc = reload(image_pyfunc)

In [50]:
mlflow.pytorch.get_default_conda_env()



{'name': 'mlflow-env',
 'channels': ['conda-forge'],
 'dependencies': ['python=3.9.7',
  'pip',
  {'pip': ['mlflow',
    'torch==1.10.0',
    'torchvision==0.11.0a0',
    'cloudpickle==2.0.0']}]}

In [51]:
cls_name

['bad', 'good']

In [52]:
image_pyfunc.save_pytorch_model(net, 'model', (1, 3, 540, 540), cls_name)



In [53]:
# def get_pytorch_env_patch():
#     e = mlflow.pytorch.get_default_conda_env()
#     e['channels'].append('pytorch')
#     e['dependencies'].extend(['pytorch', 'torchvision', 'torchaudio', 'cudatoolkit=11.3'])
#     find_pip = tuple(filter(lambda p: isinstance(p, dict) and 'pip' in p, e['dependencies']))
#     find_torch = tuple(filter(lambda p: 'torch' in p, find_pip[0]['pip']))
#     for p in find_torch:
#         find_pip[0]['pip'].remove(p)
#     return e    
#     
# print(get_pytorch_env_patch())