## 1. 사용할 패키지 불러오기

In [7]:
import pandas as pd
import cv2
import os
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import json
from data_gen.data_gen import PointInferenceDatasetGenerator, ClassInferenceDatasetGenerator
import torch
import torchvision.models as models
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

## 2. 데이터 불러오기

### (1) Inference Data 경로 지정

In [8]:
# Folder holding the test images to run inference on.
inference_img_path = "./dataset/test/img"

### (2) Inference Data 불러오기

In [9]:
# Full paths of every entry in the test image folder, in sorted file-name order.
inference_img_list = [
    os.path.join(inference_img_path, entry)
    for entry in sorted(os.listdir(inference_img_path))
]

## 3. 모델 추론 (Inference)

In [10]:
# Output folder for the cropped images; an existing folder is reused.
save_path = 'cropped_image/inference'
os.makedirs(name=save_path, exist_ok=True)

### (1) Point 예측용 Dataloader 생성

In [11]:
# Wrap the test image paths in the point-model dataset generator and get its
# dataloaders; the prediction cell iterates inference_dataloader['test'].
inference_dataset = PointInferenceDatasetGenerator(inference_img_list)
inference_dataloader = inference_dataset.dataloader()

### (2) 학습된 Point Prediction Model 불러오기

In [12]:
model_name = 'resnet18'
# Build the bare ResNet-18: the checkpoint below replaces every parameter and
# buffer ("All keys matched"), so downloading ImageNet weights is wasted work.
vision_model = models.resnet18()
num_ftrs = vision_model.fc.in_features
# 8 regression outputs: four (x, y) points used by the crop cell.
vision_model.fc = nn.Linear(num_ftrs, 8)
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; the prediction cell moves the model to the selected device.
vision_model.load_state_dict(torch.load('result/Best_model.pth', map_location='cpu'))

<All keys matched successfully>

### (3) 예측한 Point 좌표로 이미지 Crop 및 저장

In [13]:
print('Prediction')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
predictions = []
vision_model.to(device)
# Switch to inference mode once, before the loop.
vision_model.eval()


def _clamp(value, upper):
    """Clamp an integer pixel coordinate into the range [0, upper]."""
    return min(max(value, 0), upper)


def _save_image(path, image):
    """Encode *image* using the extension of *path*, write it, and return success.

    cv2.imwrite cannot open non-ASCII paths on Windows (such as the Korean file
    names in this dataset) and reports failure only through its return value,
    so encode in memory with cv2.imencode and write the bytes through NumPy.
    """
    ok, encoded = cv2.imencode(os.path.splitext(path)[1], image)
    if ok:
        encoded.tofile(path)
    return ok


with torch.no_grad():
    for data in inference_dataloader['test']:
        images = data['image'].float().to(device)
        widths = data['width'].float().reshape(-1).tolist()
        heights = data['height'].float().reshape(-1).tolist()
        # NOTE(review): assumes (batch, H, W, C), matching the original
        # [0, :, :, :][rows, cols, :] indexing — confirm in the generator.
        original_images = data['original_image'].float().numpy()
        pred = vision_model(images).cpu().numpy()

        # Process every sample in the batch (the original only used index 0).
        samples = zip(pred, original_images, widths, heights, data['fname'])
        for coords, image, width, height, fname in samples:
            # Even output indices are horizontal (x) values, odd ones vertical (y).
            xs, ys = coords[0::2], coords[1::2]
            img_h, img_w = image.shape[:2]
            # NOTE(review): x is rescaled from a 448-wide frame by `height` and y
            # from a 224-high frame by `width`. For an HWC image this looks
            # swapped; it is kept as-is because the generator may label
            # width/height that way — confirm before changing.
            # Clamping matters: a negative index would wrap around in the slice.
            left = _clamp(int(min(xs) / 448 * height), img_w)
            right = _clamp(int(max(xs) / 448 * height), img_w)
            top = _clamp(int(min(ys) / 224 * width), img_h)
            bottom = _clamp(int(max(ys) / 224 * width), img_h)
            cropped_image = image[top:bottom, left:right, :]

            save_fname = os.path.join(save_path, fname)
            if cropped_image.size == 0:
                # Encoding an empty image raises in OpenCV; report and move on.
                print(f'Skipped {save_fname}: predicted box is empty')
            elif not _save_image(save_fname, cropped_image):
                print(f'Failed to write {save_fname}')





Prediction


### (4) Class 예측용 Dataloader 생성

In [14]:
# Rebind the dataset/dataloader to the class-model generator; this replaces the
# point-model loader created earlier.
# NOTE(review): this receives the original test image paths, not the files
# written to cropped_image/inference — confirm the generator reads the crops.
inference_dataset = ClassInferenceDatasetGenerator(inference_img_list)
inference_dataloader = inference_dataset.dataloader()

### (5) 학습된 Class Prediction Model 불러오기

In [15]:
model_name = 'resnet18'
# Build the bare ResNet-18: the checkpoint below replaces every parameter and
# buffer ("All keys matched"), so downloading ImageNet weights is wasted work.
vision_model = models.resnet18()
num_ftrs = vision_model.fc.in_features
# Class names indexed by the classifier's output position.
uni_label = ['1996_n', '2004_n', '2006_eu', '2006_n', '2006_us', '2019_n', '2019_r', 'bike', 'echo']
vision_model.fc = nn.Linear(num_ftrs, len(uni_label))
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; the prediction cell moves the model to the selected device.
vision_model.load_state_dict(torch.load('result/Classification_Best_model.pth', map_location='cpu'))

<All keys matched successfully>

### (6) Class 예측

In [17]:
print('Test Prediction')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
predictions = []
vision_model.to(device)
# Set inference mode once instead of on every batch.
vision_model.eval()

with torch.no_grad():
    for data in inference_dataloader['test']:
        # A single .to(device) is enough; the batch was previously moved twice.
        images = data['image'].float().to(device)
        logits = vision_model(images)
        # Highest-scoring class index per sample. extend() appends in place
        # instead of rebuilding the whole list every batch (quadratic overall).
        predictions.extend(logits.argmax(dim=1).cpu().tolist())

Test Prediction


In [20]:
# Replace each predicted class index with its label name from uni_label.
predictions = list(map(uni_label.__getitem__, predictions))
# One row per input image.
# NOTE(review): assumes the class dataloader keeps inference_img_list order
# (no shuffling) — confirm in ClassInferenceDatasetGenerator.
result = pd.DataFrame(data={'File': inference_img_list, 'Prediction': predictions})

In [21]:
# Display the File / Prediction table as the cell's output.
result

Unnamed: 0,File,Prediction
0,./dataset/test/img/01가2636.JPG,2006_n
1,./dataset/test/img/01고9570.JPG,2006_n
2,./dataset/test/img/01구2337.jpg,2004_n
3,./dataset/test/img/01라0185.jpg,1996_n
4,./dataset/test/img/01라0553.jpg,2006_n
...,...,...
58,./dataset/test/img/전북32바7467.jpg,echo
59,./dataset/test/img/전북32바7553.jpg,echo
60,./dataset/test/img/전북32바7595.jpg,echo
61,./dataset/test/img/전북82배1120.JPG,echo
