In [31]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import os
import numpy as np

In [32]:
num_classes = 12
csv_path = "./class.csv" 

In [33]:
model = models.inception_v3(pretrained=True) #凍結權重
'''for param in model.parameters():
    param.requires_grad = False'''

'for param in model.parameters():\n    param.requires_grad = False'

In [34]:

# 替換 Inception-v3 的最後全連接層
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.train()


Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [35]:
# Inception-v3 模型需要特定的影像輸入格式（299x299 像素，RGB 圖像，並且需要正規化）。使用 torchvision.transforms 進行預處理。
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [36]:

class train_loader():
    
    def __init__(self,path):
        self.imgs=[]
        self.labels=[]
        df = pd.read_csv(path)
        for index, row in df.iterrows():
            image_path = row['path']
            label = row['label']

            if os.path.exists(image_path):
                img = Image.open(image_path)
                img_tensor = preprocess(img)
                self.imgs.append(img_tensor)
                self.labels.append(label)
                #img.show()

                # image processing
                print(f"\rloading image: {image_path} with label: {label}")
            else:
                print(f"Image not found: {image_path}\n")

        self.labels = torch.tensor(self.labels)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [37]:
data_set=train_loader(csv_path)

loading image: ./Dataset/alcohol_tissue/IMG_2555.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2556.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2557.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2558.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2559.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2560.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2561.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2562.JPG with label: 0
loading image: ./Dataset/alcohol_tissue/IMG_2563.JPG with label: 0
loading image: ./Dataset/biscoff/IMG_2755.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_2756.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_2757.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_2758.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_2759.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_2760.jpg with label: 1
loading image: ./Dataset/biscoff/IMG_

In [42]:
model.train()
for epoch in range(num_classes):  # 訓練 10 個 epoch
    print(f"epoch: {epoch}\n")
    for (images, labels) in zip(data_set.imgs,data_set.labels):  # 迭代訓練數據
        labels = torch.tensor([labels])  # 確保 labels 是正確的維度
        images = images.unsqueeze(0)
        data_set.optimizer.zero_grad()

        model.eval()  # 關閉輔助分類器的輸出


        # 進行推理時返回兩個輸出
        outputs = model(images)

        loss = data_set.criterion(outputs, labels)

        loss.backward()

        data_set.optimizer.step()


KeyboardInterrupt: 