In [3]:
import functools
import io
import json
import os
import urllib.request

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from tqdm.notebook import tqdm

모델을 새로 불러온 뒤 우리 데이터셋에 맞춰 수정

In [4]:
# Rebuild the VGG16 architecture and replace the last classifier layer with a
# 300-way head. Pretrained ImageNet weights are not requested: every parameter
# is overwritten by the fine-tuned checkpoint loaded in the next cell, so
# downloading them would be wasted work.
vgg16_model = models.vgg16(weights=None)
num_features = vgg16_model.classifier[6].in_features  # input width of the final Linear layer
vgg16_model.classifier[6] = nn.Linear(num_features, 300)  # 300 output classes

우리 데이터셋에 대해 학습한 weight 불러오기

In [5]:
# Load the weights fine-tuned on our dataset. map_location='cpu' lets a
# checkpoint that was saved from a GPU load on a CPU-only machine.
vgg16_model.load_state_dict(torch.load('vgg16_classifier_4.pth', map_location='cpu'))

<All keys matched successfully>

information[0]에는 0번 label의 정보가, information[1]에는 1번 label의 정보가, ...

In [6]:
label_folder = "label"

# information[i] must hold the metadata of class index i (the model's argmax is
# used to index it directly). os.listdir returns entries in an arbitrary,
# filesystem-dependent order, so sort the filenames to make the
# index -> label mapping deterministic.
# NOTE(review): assumes the training label indices follow this sorted filename
# order — TODO confirm against the training script.
json_filenames = sorted(
    filename for filename in os.listdir(label_folder) if filename.endswith(".json")
)

information = []
for filename in json_filenames:
    file_path = os.path.join(label_folder, filename)
    # utf-8-sig also accepts files that start with a UTF-8 BOM
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        information.append(json.load(file))

In [7]:
# Inspect one metadata record: the first entry under 'images' for label index 0.
label0_images = information[0]['images']
sampledict = label0_images[0]
sampledict


{'file_name': 'K-029818_0_2_0_0_75_000_200.png',
 'width': 976,
 'height': 1280,
 'imgfile': 'K-029818_0_2_0_0_75_000_200.png',
 'drug_N': 'K-029818',
 'drug_S': '정상알약',
 'back_color': '연회색 배경',
 'drug_dir': '앞면',
 'light_color': '전구색',
 'camera_la': 75,
 'camera_lo': 0,
 'size': 200,
 'item_seq': '201206159',
 'leng_long': '13.2',
 'leng_short': '7.4',
 'thick': '4.3',
 'id': 1,
 'dl_idx': '29817',
 'dl_mapping_code': 'K-029818',
 'dl_name': '지소렌정',
 'dl_name_en': 'G-Soren Tab.',
 'img_key': 'http://connectdi.com/design/img/drug/147426622886300027.jpg',
 'dl_material': '애엽이소프로판올연조엑스',
 'dl_material_en': 'Artemisia Herb Isopropanol Soft Ext.(20→1)',
 'dl_custom_shape': '정제, 저작정',
 'dl_company': '지엘팜텍(주)',
 'dl_company_en': 'Glpharmtech',
 'di_company_mf': '풍림무약',
 'di_company_mf_en': 'Richwood Trading Company',
 'di_item_permit_date': '20120717',
 'di_class_no': '[02320]소화성궤양용제',
 'di_etc_otc_code': '전문의약품',
 'di_edi_code': '624200010',
 'chart': '녹색의 타원형 필름코팅정제',
 'drug_shape': '타원형',

In [34]:
# Fields shown in the app, looked up in each record of information[i]['images'].
# Keys left out on purpose are kept as comments for reference.
infolist = [
    'drug_N',
    'drug_S',
    'drug_dir',
    'light_color',
    'dl_name',
    # 'dl_name_en',
    # 'img_key',
    'dl_material',
    # 'dl_material_en',
    'dl_custom_shape',
    'dl_company',
    # 'dl_company_en',
    'di_company_mf',
    'di_item_permit_date',
    'di_class_no',
    'di_etc_otc_code',
    'di_edi_code',
    'chart',
    # 'drug_shape',
    # 'form_code_name',
]

In [35]:
# Values of the selected fields for the sample record, in infolist order (None if absent).
list(map(sampledict.get, infolist))

['K-029818',
 '정상알약',
 '앞면',
 '전구색',
 '지소렌정',
 '애엽이소프로판올연조엑스',
 '정제, 저작정',
 '지엘팜텍(주)',
 '풍림무약',
 '20120717',
 '[02320]소화성궤양용제',
 '전문의약품',
 '624200010',
 '녹색의 타원형 필름코팅정제']

# 어플리케이션 구현

In [10]:
# import torch
# import torch.nn as nn
# from torchvision import models, transforms
# from PIL import Image
# from tqdm.notebook import tqdm

import gradio as gr


In [36]:
@functools.lru_cache(maxsize=2)
def _load_vgg16(weight_path):
    """Build the 300-class VGG16, load the checkpoint at weight_path, and return it in eval mode.

    Cached per path so each request does not rebuild the network and re-read the
    checkpoint. Call ``_load_vgg16.cache_clear()`` after overwriting a weight file.
    """
    # ImageNet weights are not requested: load_state_dict below replaces every parameter.
    model = models.vgg16(weights=None)
    num_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_features, 300)  # 300 output classes
    # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    model.eval()
    return model


def infer_vgg16(weight_path, image_tensor):
    """Run the fine-tuned VGG16 stored at weight_path on a batch of images.

    Args:
        weight_path: path to a state_dict for the 300-class VGG16 head.
        image_tensor: CPU image batch, e.g. (1, 3, 224, 224) as built by gradio_interface.

    Returns:
        Raw scores from the final Linear layer (no softmax), shape (N, 300).
    """
    model = _load_vgg16(weight_path)
    with torch.inference_mode():
        output = model(image_tensor)
    return output

# Preprocessing for uploaded images, built once instead of on every request.
# It matches the resize + ToTensor pipeline in this notebook's test code (no Normalize).
# NOTE(review): confirm this matches the preprocessing used during training.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _fetch_preview(url, timeout=10):
    """Download and decode the reference image at url; return None if that fails."""
    if not url:
        return None
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            payload = response.read()
        preview = Image.open(io.BytesIO(payload))
        preview.load()  # decode now so corrupt data is caught inside this try
        return preview
    except (OSError, ValueError):
        # URLError, timeouts and PIL's UnidentifiedImageError are OSError
        # subclasses; ValueError covers malformed URLs. The preview is optional,
        # so a failure yields an empty image output instead of failing the request.
        return None


def gradio_interface(weight_path, image):
    """Classify an uploaded pill image and return the predicted label's metadata.

    Args:
        weight_path: checkpoint path selected in the dropdown.
        image: uploaded image as an array (passed to Image.fromarray), or None.

    Returns:
        (output, preview_image): the values of ``infolist`` fields from the
        predicted label's first 'images' record, and the reference image
        downloaded from its 'img_key' URL (None if it cannot be fetched).

    Raises:
        gr.Error: if no image or weight path is given, or the predicted index
            has no entry in ``information``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not weight_path:
        raise gr.Error("Please select a weight path.")

    #### image to tensor ####
    # convert('RGB') also handles grayscale / RGBA arrays, which forcing mode 'RGB' would misread.
    pil_image = Image.fromarray(image).convert('RGB')
    image_tensor = _PREPROCESS(pil_image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    #### inference : out_idx(int) ####
    vggout = infer_vgg16(weight_path=weight_path, image_tensor=image_tensor)
    out_idx = int(vggout.argmax())

    #### output data from information ####
    if out_idx >= len(information):
        raise gr.Error(f"Predicted class {out_idx} has no label metadata.")
    out_dict = information[out_idx]['images'][0]
    output = [out_dict.get(info) for info in infolist]
    preview_image = _fetch_preview(out_dict.get('img_key'))
    return output, preview_image


# Checkpoints selectable in the UI.
weightlist = [
    "./vgg16_classifier_4.pth",
]
demo = gr.Interface(
    gradio_interface,
    [
        # Default to the first checkpoint so weight_path is not None when the
        # user submits without touching the dropdown.
        gr.Dropdown(weightlist, value=weightlist[0], label="weight path", info="test dropdown"),
        gr.Image(),
    ],
    outputs=['text', gr.Image()],
    allow_flagging='never',
)

### test code

In [12]:
# image_path = './testimage001.jpg'
# image = Image.open(image_path).convert('RGB')
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])
# image_tensor = transform(image).unsqueeze(0)
# print(image_tensor.shape)
# output = infer_vgg16(weight_path='./vgg16_classifier_4.pth', image_tensor=image_tensor)
# int(output.argmax())

# gradio

In [37]:
# `demo` is re-created each time the interface cell runs, so demo.close() would
# only close the new, not-yet-launched instance and leave servers from earlier
# launches running. close_all() shuts down every running Gradio app first.
gr.close_all()
demo.launch()

Running on local URL:  http://127.0.0.1:7865

To create a public link, set `share=True` in `launch()`.


