In [8]:
import torch
import torchvision
import matplotlib.pyplot as plt

from torchvision import transforms
from collections import namedtuple
from torchvision.datasets import ImageFolder

import torchvision.models as models

from sklearn.metrics import classification_report
from torch.nn import functional as F

from torch import nn

In [9]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


#Chuẩn bị dữ liệu

In [10]:
TrainTest = namedtuple('TrainTest', ['train', 'test'])

#hàm biến đổi dữ liệu 
def prepare_data():
  transform_train = transforms.Compose([
                                  transforms.Resize((256, 256)),                                        #resize images
                                  transforms.RandomCrop(224, padding=4),                 #lấy vùng ngẫu nhiên trong ảnh
                                  transforms.RandomHorizontalFilp(),                             #lat ngang
                                  transforms.ToTensor(),                                                       #convert to Tensor
                                  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])

  transform_test = transforms.Compose([
                                       transforms.Resize((224, 224)),
                                       transforms.ToTensor(),                           
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
  
  path_data_train = "/content/drive/MyDrive/kì_1_4/Xu_ly_anh/Thay_cuong/gk4/simple_OCR/data/train"
  path_data_test = "/content/drive/MyDrive/kì_1_4/Xu_ly_anh/Thay_cuong/gk4/simple_OCR/data/test"

  trainset = torchvision.datasets.ImageFolder(root=path_data_train, transform=transform_train)
  testset = torchvision.datasets.ImageFolder(root=path_data_test, transform=transform_test)

  return TrainTest(
      train=trainset,
      test=testset
  )

In [11]:
#hàm load dữ liệu thành các batch để xử lý 
def prepare_loader(datasets):
  batch_size = 8
  trainloader = torch.utils.data.DataLoader(
      datasets.train, batch_size=batch_size, shuffle=True, num_workers=2)
  testloader = torch.utils.data.DataLoader(
      datasets.test, batch_size=batch_size, shuffle=False, num_workers=2)
  return TrainTest(
      train=trainloader,
      test=testloader
  )

#Mô hình

In [12]:
class MyNet(nn.Module):
  def __init__(self):
    super().__init__()          #gọi hàm khởi tạo của lớp nn
    self.features = models.resnet18(pretrained=True)
    self.features.fc = torch.nn.Linear(512, 4)        #chuyển thành 4 classes 
  
  def forward(self, x):
    out = self.features(x)
    out = F.log_softmax(out, dim=1)
    return out

#Train

In [13]:
def train_epoch(epoch, model, train_loader, loss_func, optimizer, device):
  model.train()
  running_loss = 0.0  #gía trị hàm loss 
  reporting_step = 8

  #step = 0

  for i, (images, labels) in enumerate(train_loader):
    images, labels = images.to(device), labels.to(device)
    #step += 1

    #không tích luỹ đạo hàm 
    optimizer.zero_grad()
    outputs = model(images)                 
    loss = loss_func(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    #print(f"Reporting_step:{reporting_step} i:{i}")
    if i % reporting_step == reporting_step - 1:
      print("----------------------------------------------------------------------------------------------------")
      print(f"Epoch: {epoch} step: {i} average loss {running_loss / reporting_step:0.4f}")
      running_loss = 0.0

def test_epoch(epoch, model, test_loader, device):
  model.eval()
  y_pred = []
  y_true = []
  
  for i, (images, labels) in enumerate(test_loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs, dim=1)
    y_pred += list(predicted.cpu().numpy())
    y_true += list(labels.cpu().numpy())
  return y_pred, y_true

#loss và optimizer 
def get_trainer(model):
  loss = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
  return loss, optimizer

#Hàm main

In [14]:
def main(PATH='./model_simpleOCR.pth'):
  classes = ['highlands', 'others', 'phuclong', 'starbucks']
  datasets = prepare_data()         #convert dữ liệu 

  print("data", len(datasets.train), len(datasets.test))

  loaders =prepare_loader(datasets)
  model_simpleOCR = MyNet()
  loss, optimizer = get_trainer(model_simpleOCR)
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  number_epoch = 20

  print("number class: ", len(classes))

  model_simpleOCR.to(device)

  for epoch in range(number_epoch):
    train_epoch(epoch, model_simpleOCR, loaders.train, loss, optimizer, device)
    y_pred, y_true = test_epoch(epoch, model_simpleOCR, loaders.test, device)

    print(classification_report(y_true, y_pred, target_names=classes))
    torch.save(model_simpleOCR.state_dict(), PATH)
  
  return model_simpleOCR

model_simpleOCR = main()

data 496 100
number class:  4
Epoch: 0 step: 7 average loss 1.4112
Epoch: 0 step: 15 average loss 1.1403
Epoch: 0 step: 23 average loss 0.8828
Epoch: 0 step: 31 average loss 1.6794
Epoch: 0 step: 39 average loss 1.7244
Epoch: 0 step: 47 average loss 1.8562
Epoch: 0 step: 55 average loss 1.6979
              precision    recall  f1-score   support

   highlands       0.51      0.97      0.67        40
      others       0.00      0.00      0.00        18
    phuclong       0.48      0.42      0.45        26
   starbucks       1.00      0.06      0.12        16

    accuracy                           0.51       100
   macro avg       0.50      0.37      0.31       100
weighted avg       0.49      0.51      0.40       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 1 step: 7 average loss 1.0598
Epoch: 1 step: 15 average loss 1.8702
Epoch: 1 step: 23 average loss 2.1647
Epoch: 1 step: 31 average loss 1.8590
Epoch: 1 step: 39 average loss 2.9390
Epoch: 1 step: 47 average loss 2.5448
Epoch: 1 step: 55 average loss 3.2971
              precision    recall  f1-score   support

   highlands       0.33      0.07      0.12        40
      others       0.00      0.00      0.00        18
    phuclong       0.29      0.62      0.40        26
   starbucks       0.08      0.19      0.12        16

    accuracy                           0.22       100
   macro avg       0.18      0.22      0.16       100
weighted avg       0.22      0.22      0.17       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 2 step: 7 average loss 2.9249
Epoch: 2 step: 15 average loss 2.5878
Epoch: 2 step: 23 average loss 2.2939
Epoch: 2 step: 31 average loss 2.4402
Epoch: 2 step: 39 average loss 2.1545
Epoch: 2 step: 47 average loss 2.2033
Epoch: 2 step: 55 average loss 2.3217
              precision    recall  f1-score   support

   highlands       0.68      0.90      0.77        40
      others       0.00      0.00      0.00        18
    phuclong       0.47      0.85      0.60        26
   starbucks       0.00      0.00      0.00        16

    accuracy                           0.58       100
   macro avg       0.29      0.44      0.34       100
weighted avg       0.39      0.58      0.47       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 3 step: 7 average loss 2.1610
Epoch: 3 step: 15 average loss 1.8136
Epoch: 3 step: 23 average loss 1.8298
Epoch: 3 step: 31 average loss 1.5308
Epoch: 3 step: 39 average loss 2.7645
Epoch: 3 step: 47 average loss 2.0644
Epoch: 3 step: 55 average loss 1.8606
              precision    recall  f1-score   support

   highlands       0.65      0.28      0.39        40
      others       0.20      0.44      0.28        18
    phuclong       0.25      0.04      0.07        26
   starbucks       0.13      0.31      0.18        16

    accuracy                           0.25       100
   macro avg       0.31      0.27      0.23       100
weighted avg       0.38      0.25      0.25       100

Epoch: 4 step: 7 average loss 1.3536
Epoch: 4 step: 15 average loss 1.7718
Epoch: 4 step: 23 average loss 1.9793
Epoch: 4 step: 31 average loss 1.4807
Epoch: 4 step: 39 average loss 1.5193
Epoch: 4 step: 47 average loss 1.4624
Epoch: 4 step: 55 average loss 1.1753
              precision    recall  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 5 step: 7 average loss 1.6836
Epoch: 5 step: 15 average loss 1.6846
Epoch: 5 step: 23 average loss 2.1757
Epoch: 5 step: 31 average loss 1.2917
Epoch: 5 step: 39 average loss 1.4765
Epoch: 5 step: 47 average loss 0.9999
Epoch: 5 step: 55 average loss 0.8966
              precision    recall  f1-score   support

   highlands       0.97      0.97      0.97        40
      others       1.00      0.06      0.11        18
    phuclong       0.00      0.00      0.00        26
   starbucks       0.27      1.00      0.43        16

    accuracy                           0.56       100
   macro avg       0.56      0.51      0.38       100
weighted avg       0.61      0.56      0.48       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 6 step: 7 average loss 0.6431
Epoch: 6 step: 15 average loss 0.5511
Epoch: 6 step: 23 average loss 0.8247
Epoch: 6 step: 31 average loss 1.7645
Epoch: 6 step: 39 average loss 1.7266
Epoch: 6 step: 47 average loss 0.7463
Epoch: 6 step: 55 average loss 0.9772
              precision    recall  f1-score   support

   highlands       0.95      0.95      0.95        40
      others       0.94      0.94      0.94        18
    phuclong       0.63      0.85      0.72        26
   starbucks       0.71      0.31      0.43        16

    accuracy                           0.82       100
   macro avg       0.81      0.76      0.76       100
weighted avg       0.83      0.82      0.81       100

Epoch: 7 step: 7 average loss 0.7017
Epoch: 7 step: 15 average loss 0.6042
Epoch: 7 step: 23 average loss 0.1056
Epoch: 7 step: 31 average loss 1.0181
Epoch: 7 step: 39 average loss 0.4318
Epoch: 7 step: 47 average loss 0.5173
Epoch: 7 step: 55 average loss 0.4780
              precision    recall  