<a href="https://colab.research.google.com/github/hsyi123/Edge_AI/blob/main/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%cd /content/
# Clone the data-generation tool (Single_char_image_generator) from GitHub,
# then change into the cloned directory so the commands that follow run from
# inside the repository.
!git clone https://github.com/rachellin0105/Single_char_image_generator.git
%cd Single_char_image_generator


/content
Cloning into 'Single_char_image_generator'...
remote: Enumerating objects: 508, done.[K
remote: Counting objects: 100% (60/60), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 508 (delta 55), reused 47 (delta 47), pack-reused 448[K
Receiving objects: 100% (508/508), 428.80 MiB | 16.08 MiB/s, done.
Resolving deltas: 100% (118/118), done.
/content/Single_char_image_generator


In [None]:
# Single_char_image_generator/chars.txt is the character dictionary. It ships
# with 102 characters by default, and entries can be added or removed there.
# Here only the first 40 lines are kept. The result is written to a temporary
# file first because redirecting `head` straight back into chars.txt would
# truncate the file before it is read.
!head -n 40 chars.txt > temp.txt
!mv temp.txt chars.txt

In [None]:

# Install the packages the generator needs.
!python -m pip install -r requirements.txt

# Run the generator with a single command.
# NOTE(review): --num_per_word=500 presumably requests 500 images per
# dictionary character — confirm against the generator's own documentation.
!python OCR_image_generator_single_ch.py --num_per_word=500

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
/content/.caches/a56a73c2d7a54d049026dfa482bfaac9
Save font(./fonts/chinse_jian/simfang.ttf) supported chars(40) to cache
/content/.caches/ceb1594269364fa1c5230afd0053bf61
Save font(./fonts/chinse_jian/2.ttf) supported chars(40) to cache
Start generating...
Saving images in directory : output
 98% 39/40 [23:23<00:35, 35.84s/it]

## 使用 PyTorch 訓練 ResNet-18

In [None]:
import torch
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU, and report which one was chosen.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('GPU' if device.type == 'cuda' else 'CPU')


In [None]:
class ChineseCharDataset(Dataset):
  """Single-character image dataset driven by a label file and a dictionary.

  Each sample is an ``(image, label)`` tuple: the image converted to a tensor
  by ``transforms.ToTensor()`` and the label as a one-element float tensor
  holding the character's index in the dictionary.
  """

  def __init__(self, data_file, root_dir, dict_file):
    """Load the annotation table and build the character-to-index mapping.

    Args:
      data_file: Path to the header-less CSV annotation file
        (column 0: image path, column 1: ground-truth character).
      root_dir: Directory that the image paths in ``data_file`` are joined to.
      dict_file: Path to the dictionary file, one entry per line; the order
        of the non-blank lines defines the class indices.
    """
    # Read the generated labels file as a CSV. Every column is kept as text:
    # without dtype=str pandas would turn an all-digit label column into
    # integers, which could never match the string keys of the dictionary.
    self.char_dataframe = pd.read_csv(
        data_file, index_col=False, encoding='utf-8', header=None, dtype=str)
    self.root_dir = root_dir
    with open(dict_file, 'r', encoding='utf-8') as f:
      # Load the character set covered by the dataset, skipping blank lines.
      word_list = [line for line in f.read().split('\n') if line.strip() != '']
    self.dictionary = {word: index for index, word in enumerate(word_list)}

    print(self.char_dataframe)
    print(self.dictionary)

  def __len__(self):
    """Return the number of annotated samples."""
    return len(self.char_dataframe)

  def __getitem__(self, idx):
    """Return the ``(image_tensor, label_tensor)`` pair for sample ``idx``.

    Raises:
      KeyError: If the sample's ground-truth character is not in the dictionary.
    """
    # Resolve the path of the idx-th image and open it. The context manager
    # releases the file handle; convert('RGB') guarantees the 3 channels the
    # network's first convolution expects, whatever mode the file is stored in.
    image_path = os.path.join(self.root_dir, self.char_dataframe.iloc[idx, 0])
    with Image.open(image_path) as image:
      image_tensor = transforms.ToTensor()(image.convert('RGB'))

    # Look up the ground-truth character and convert it to its class index.
    char = self.char_dataframe.iloc[idx, 1]
    if char not in self.dictionary:
      raise KeyError(f'label {char!r} of sample {idx} is not in the dictionary')
    char_num = self.dictionary[char]

    return (image_tensor, torch.Tensor([char_num]))
    

In [None]:
%cd /content/

# 宣告好所有要傳入 ChineseCharDataset 的引數
data_file_path = './Single_char_image_generator/output/labels.txt'
root_dir = './Single_char_image_generator/'
dict_file_path = './Single_char_image_generator/chars.txt'

# 模型儲存位置
save_path = './checkpoint.pt'

# 宣告我們自訂的Dataset，把它包到 Dataloader 中以便我們訓練使用
char_dataset = ChineseCharDataset(data_file_path, root_dir, dict_file_path)
char_dataloader = DataLoader(char_dataset, batch_size=64, shuffle=True, num_workers=2)

In [None]:
# --- Training ---

# Use the ResNet-18 provided by torchvision as the model.
num_classes = 40  # number of classes (distinct characters)

# Fail early with a readable message: a label index >= num_classes would
# otherwise surface as an opaque error inside the loss computation.
if len(char_dataloader.dataset.dictionary) > num_classes:
  raise ValueError(
      f'dictionary has {len(char_dataloader.dataset.dictionary)} entries '
      f'but the model only has {num_classes} output classes')

net = models.resnet18(num_classes=num_classes)
net = net.to(device)  # move the model to the selected device
net.train()

optimizer = optim.Adam(net.parameters(), lr=0.005)

# Total number of training epochs.
epochs = 30


each_loss = []  # average loss of every finished epoch
for i in tqdm(range(1, epochs + 1)):
  losses = 0
  for idx, data in enumerate(char_dataloader):
    image, label = data
    image = image.to(device)
    # Flatten the labels from (batch, 1) to (batch,). reshape(-1) keeps the
    # batch dimension even for a single-sample batch, whereas squeeze() would
    # collapse it to a 0-d tensor and break cross_entropy.
    label = label.reshape(-1).to(device, dtype=torch.long)

    optimizer.zero_grad()
    result = net(image)

    # Compute the loss.
    loss = F.cross_entropy(result, label)
    losses += loss.item()
    if idx % 10 == 0:  # report once every 10 batches
      print(f'epoch {i}- loss: {loss.item()}')

    # Back-propagate and update the model parameters.
    loss.backward()
    optimizer.step()

  avgloss = losses / len(char_dataloader)
  each_loss.append(avgloss)
  print(f'{i}th epoch end. Avg loss: {avgloss}')

# Save the model and optimizer state.
torch.save({
  'epoch': epochs,
  'model_state_dict': net.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
}, save_path)

# Plot the training curve (y axis: loss, x axis: epoch). Epochs are numbered
# from 1, so the x values are given explicitly instead of the default 0-based
# index.
plt.plot(range(1, len(each_loss) + 1), each_loss, '-b', label='loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

    

  0%|          | 0/30 [00:00<?, ?it/s]

epoch 1- loss: 3.8148410320281982
epoch 1- loss: 4.187935829162598
epoch 1- loss: 3.8179514408111572
epoch 1- loss: 3.864957809448242
epoch 1- loss: 3.7642159461975098
epoch 1- loss: 3.7038328647613525
epoch 1- loss: 3.7086989879608154
epoch 1- loss: 3.4635119438171387
epoch 1- loss: 3.31970477104187
epoch 1- loss: 3.3028862476348877
epoch 1- loss: 2.905038833618164
epoch 1- loss: 2.6282732486724854
epoch 1- loss: 1.7836042642593384
epoch 1- loss: 1.2610650062561035
epoch 1- loss: 1.1504255533218384
epoch 1- loss: 1.3057258129119873
epoch 1- loss: 1.1957378387451172
epoch 1- loss: 1.2346185445785522
epoch 1- loss: 1.1439452171325684
epoch 1- loss: 1.0852470397949219
epoch 1- loss: 1.0847910642623901
epoch 1- loss: 1.1823923587799072
epoch 1- loss: 1.1600979566574097
epoch 1- loss: 1.0332534313201904
epoch 1- loss: 0.8823889493942261
epoch 1- loss: 1.0011093616485596
epoch 1- loss: 0.8756234645843506
epoch 1- loss: 0.7632589340209961
epoch 1- loss: 0.8814308643341064
epoch 1- loss: 0.98

  3%|▎         | 1/30 [00:25<12:32, 25.96s/it]

epoch 1- loss: 0.9687943458557129
1th epoch end. Avg loss: 2.0006575736755763
epoch 2- loss: 0.863336443901062
epoch 2- loss: 0.9731904864311218
epoch 2- loss: 0.9352869987487793
epoch 2- loss: 0.8585711121559143
epoch 2- loss: 0.8880569338798523
epoch 2- loss: 1.0341097116470337
epoch 2- loss: 1.075272798538208
epoch 2- loss: 0.8755770921707153
epoch 2- loss: 1.000191330909729
epoch 2- loss: 0.8533651232719421
epoch 2- loss: 0.8850147724151611
epoch 2- loss: 0.8267309665679932
epoch 2- loss: 1.009175419807434
epoch 2- loss: 0.8572924733161926
epoch 2- loss: 0.8540958762168884
epoch 2- loss: 0.771943986415863
epoch 2- loss: 0.966766893863678
epoch 2- loss: 0.8897380232810974
epoch 2- loss: 0.8211349844932556
epoch 2- loss: 0.8137085437774658
epoch 2- loss: 1.0870696306228638
epoch 2- loss: 0.8907762765884399
epoch 2- loss: 0.8384740948677063
epoch 2- loss: 0.8597477674484253
epoch 2- loss: 0.7410468459129333
epoch 2- loss: 0.9192219972610474
epoch 2- loss: 0.8509759306907654
epoch 2- l

  7%|▋         | 2/30 [00:38<08:31, 18.25s/it]

epoch 2- loss: 0.7604695558547974
2th epoch end. Avg loss: 0.852564136060282
epoch 3- loss: 0.7477419972419739
epoch 3- loss: 0.9314354658126831
epoch 3- loss: 0.9414159059524536
epoch 3- loss: 0.7702266573905945
epoch 3- loss: 0.869100034236908
epoch 3- loss: 0.7446704506874084
epoch 3- loss: 0.74228435754776
epoch 3- loss: 0.7387475371360779
epoch 3- loss: 0.8270977139472961
epoch 3- loss: 0.7905129790306091
epoch 3- loss: 0.8422255516052246
epoch 3- loss: 0.7704252600669861
epoch 3- loss: 0.7818064093589783
epoch 3- loss: 0.8871384859085083
epoch 3- loss: 0.782004177570343
epoch 3- loss: 0.8047785758972168
epoch 3- loss: 0.8312999606132507
epoch 3- loss: 0.7281315922737122
epoch 3- loss: 0.7828102111816406
epoch 3- loss: 0.7768434882164001
epoch 3- loss: 0.7481347322463989
epoch 3- loss: 0.7364648580551147
epoch 3- loss: 0.7031259536743164
epoch 3- loss: 0.8244256377220154
epoch 3- loss: 0.697899580001831
epoch 3- loss: 0.7831695079803467
epoch 3- loss: 0.7495083212852478
epoch 3- l

 10%|█         | 3/30 [00:51<07:06, 15.81s/it]

epoch 3- loss: 0.7530217170715332
3th epoch end. Avg loss: 0.7883872968701128
epoch 4- loss: 0.7070959806442261
epoch 4- loss: 0.7131262421607971
epoch 4- loss: 0.7795122265815735
epoch 4- loss: 0.6976641416549683
epoch 4- loss: 0.7107453346252441
epoch 4- loss: 0.715610921382904
epoch 4- loss: 0.8611559867858887
epoch 4- loss: 0.8266251683235168
epoch 4- loss: 0.7299543619155884
epoch 4- loss: 0.7827078700065613
epoch 4- loss: 0.7823619246482849
epoch 4- loss: 0.762150228023529
epoch 4- loss: 0.6916908025741577
epoch 4- loss: 0.7340007424354553
epoch 4- loss: 0.6460446715354919
epoch 4- loss: 0.6736921072006226
epoch 4- loss: 0.7227196097373962
epoch 4- loss: 0.7675772309303284
epoch 4- loss: 0.7277911305427551
epoch 4- loss: 0.6997172832489014
epoch 4- loss: 0.79668128490448
epoch 4- loss: 0.7232690453529358
epoch 4- loss: 0.7337512969970703
epoch 4- loss: 0.7992051839828491
epoch 4- loss: 0.8684308528900146
epoch 4- loss: 0.757362961769104
epoch 4- loss: 0.6825699210166931
epoch 4- 

 13%|█▎        | 4/30 [01:04<06:22, 14.71s/it]

epoch 4- loss: 0.7105065584182739
4th epoch end. Avg loss: 0.7574085557041839
epoch 5- loss: 0.7601500749588013
epoch 5- loss: 0.747754693031311
epoch 5- loss: 0.788899838924408
epoch 5- loss: 0.7063006162643433
epoch 5- loss: 0.724290668964386
epoch 5- loss: 0.8026777505874634
epoch 5- loss: 0.719792366027832
epoch 5- loss: 0.8455401659011841
epoch 5- loss: 0.8038150668144226
epoch 5- loss: 0.8285909295082092
epoch 5- loss: 0.6802284717559814
epoch 5- loss: 0.7176029086112976
epoch 5- loss: 0.6803160309791565
epoch 5- loss: 0.7368060350418091
epoch 5- loss: 0.7494210004806519
epoch 5- loss: 0.7551848888397217
epoch 5- loss: 0.6915149092674255
epoch 5- loss: 0.7643218636512756
epoch 5- loss: 0.7175635695457458
epoch 5- loss: 0.7457119226455688
epoch 5- loss: 0.7626882791519165
epoch 5- loss: 0.7528018951416016
epoch 5- loss: 0.7147287130355835
epoch 5- loss: 1.0688910484313965
epoch 5- loss: 0.7493776679039001
epoch 5- loss: 0.8142896294593811
epoch 5- loss: 0.6773440837860107
epoch 5-

 17%|█▋        | 5/30 [01:17<05:52, 14.09s/it]

epoch 5- loss: 0.7494186758995056
5th epoch end. Avg loss: 0.7474013370827745
epoch 6- loss: 0.7862500548362732
epoch 6- loss: 0.7118969559669495
epoch 6- loss: 0.7500959634780884
epoch 6- loss: 0.7793929576873779
epoch 6- loss: 0.7764898538589478
epoch 6- loss: 0.7546265125274658
epoch 6- loss: 0.7696437835693359
epoch 6- loss: 0.7683655619621277
epoch 6- loss: 0.7646577954292297
epoch 6- loss: 0.6984355449676514
epoch 6- loss: 0.7103573083877563
epoch 6- loss: 0.7214754819869995
epoch 6- loss: 0.7862439751625061
epoch 6- loss: 0.6945375204086304
epoch 6- loss: 0.7856948971748352
epoch 6- loss: 0.7351394295692444
epoch 6- loss: 0.7221782207489014
epoch 6- loss: 0.680417001247406
epoch 6- loss: 0.7892569303512573
epoch 6- loss: 0.7685316205024719
epoch 6- loss: 0.6860811114311218
epoch 6- loss: 0.7931084036827087
epoch 6- loss: 0.7381157875061035
epoch 6- loss: 0.7205492258071899
epoch 6- loss: 0.7413592338562012
epoch 6- loss: 0.6313061118125916
epoch 6- loss: 0.6481010913848877
epoch

 20%|██        | 6/30 [01:30<05:29, 13.74s/it]

epoch 6- loss: 0.6928310990333557
6th epoch end. Avg loss: 0.7392893995339878
epoch 7- loss: 0.6954134106636047
epoch 7- loss: 0.7487210631370544
epoch 7- loss: 0.7506810426712036


In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms

# Define the accuracy-evaluation function.
def compute_accuracy(model, test_dataset):
    """Return the fraction of samples that ``model`` classifies correctly.

    Args:
        model: A ``torch.nn.Module`` whose output has one score per class
            along dimension 1.
        test_dataset: Iterable yielding ``(images, labels)`` batches. Labels
            may be shaped ``(batch,)`` or ``(batch, 1)`` and may be float or
            integer tensors holding class indices.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when there are no samples.
    """
    # Run the inputs on whatever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # Evaluate in eval mode (BatchNorm/Dropout behave differently in train
    # mode), then restore the caller's mode afterwards.
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_dataset:
                if target_device is not None:
                    images = images.to(target_device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1).cpu()
                # Flatten to (batch,) so the comparison is element-wise and
                # not a (batch, batch) broadcast against (batch, 1) labels.
                targets = labels.reshape(-1).to(dtype=torch.long).cpu()
                # Count samples, not batches.
                total += targets.numel()
                correct += (predicted == targets).sum().item()
    finally:
        model.train(was_training)

    return correct / total if total else 0.0

# Compute the accuracy of the network over the batches of char_dataloader
# and report it as a percentage.
accuracy = compute_accuracy(model=net, test_dataset=char_dataloader)
print("Accuracy: {:.2f}%".format(100 * accuracy))

In [None]:
# Create a model instance with the same architecture as the trained model.
net = models.resnet18(num_classes=40)

# Load the saved weights. map_location='cpu' lets a checkpoint that was saved
# from a GPU run be restored on a CPU-only runtime; this model stays on the
# CPU for inference.
checkpoint = torch.load(save_path, map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict'])

# Switch the model to evaluation mode.
net.eval()

# Prepare the image to run inference on. The preprocessing matches what the
# training dataset feeds the network (ToTensor() only): resizing to 224x224 or
# applying ImageNet normalisation here would give the model inputs it was
# never trained on. convert('RGB') guarantees three channels.
image_path = './Single_char_image_generator/output/img_0000014.jpg'
with Image.open(image_path) as image_file:
    image = image_file.convert('RGB')
transform = transforms.ToTensor()
input_image = transform(image).unsqueeze(0)  # add the batch dimension

# Run the model.
with torch.no_grad():
    outputs = net(input_image)

# Take the highest-scoring class as the prediction.
_, predicted = torch.max(outputs, 1)
prediction = predicted.item()

# Print the predicted class index.
print('Prediction:', prediction)


In [None]:
# Mapping from class index to character. The order must match the order of
# the dictionary file used for training, because the dataset assigns class
# indices by line position starting at 0.
class_labels = ['肉','古','幼','酥','成','傢','婦','汎','貨','理','男','大','老','樹','民','鴻','禾','髮','酒','麗','鹽','容','由','寵','中','速','食','汽','子','院','批','洗','素','我','快','雞','出','動','品','活']  # replace with your actual class labels

# Take the highest-scoring class as the prediction.
_, predicted = torch.max(outputs, 1)
prediction = predicted.item()

# Look up the character for the predicted class. Class indices are 0-based,
# so the index is used as-is; subtracting 1 would shift every prediction by
# one and map class 0 to the last label.
predicted_label = class_labels[prediction]

image = Image.open(image_path)

plt.imshow(image)
plt.axis('off')
plt.show()
# Print the predicted character.
print('Prediction:', predicted_label)
