In [1]:
import numpy as np
import torch
from vit_pytorch import ViT

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")

In [2]:
import pandas as pd
from huggingface_hub import hf_hub_download
import joblib

drug_name_encoder = joblib.load(
    hf_hub_download("brainer/pill-identifier", "Condensed-Co-Graph-And-Size-Graph/drug_name_encoder.pkl")
)

id2label = pd.DataFrame(drug_name_encoder.classes_).to_dict()[0]
label2id = {v: k for k, v in id2label.items()}
id2label, label2id

({0: '195900001',
  1: '195900032',
  2: '195900043',
  3: '196000001',
  4: '196000003',
  5: '196000008',
  6: '196000011',
  7: '196200002',
  8: '196200032',
  9: '196200034',
  10: '196200043',
  11: '196200046',
  12: '196300001',
  13: '196300019',
  14: '196300064',
  15: '196400005',
  16: '196400037',
  17: '196400046',
  18: '196400099',
  19: '196500004',
  20: '196500051',
  21: '196600007',
  22: '196600011',
  23: '196600012',
  24: '196700049',
  25: '196700060',
  26: '196800048',
  27: '196800078',
  28: '196900043',
  29: '197000005',
  30: '197000037',
  31: '197000040',
  32: '197000049',
  33: '197000050',
  34: '197000053',
  35: '197000068',
  36: '197000076',
  37: '197000079',
  38: '197000102',
  39: '197000104',
  40: '197000196',
  41: '197000208',
  42: '197000211',
  43: '197000212',
  44: '197000215',
  45: '197100015',
  46: '197100073',
  47: '197100081',
  48: '197100088',
  49: '197100097',
  50: '197100165',
  51: '197200016',
  52: '197200084',
  5

In [3]:
model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=55808,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
)

model.load_state_dict(torch.load("./model/pretrained-net.pt", map_location=device))
model

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=512, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
      

In [4]:
model(torch.rand(1, 3, 256, 256))

tensor([[ 0.5097, -0.2324, -0.8648,  ..., -0.3714,  0.4941,  0.2402]],
       grad_fn=<AddmmBackward0>)

# Fine tune with MLP
## Change last layer

In [5]:
import torch

for param in model.parameters():
    param.requires_grad = False

last_layer = torch.nn.Sequential(torch.nn.Linear(1024, out_features=len(label2id)), torch.nn.ReLU())
print(last_layer)

model.mlp_head = last_layer
model.mlp_head.requires_grad_(True)
model

Sequential(
  (0): Linear(in_features=1024, out_features=25578, bias=True)
  (1): ReLU()
)


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=512, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
      

# Fine tune

In [6]:
# label2id keys to int

for key in list(label2id.keys()):
    label2id[int(key)] = label2id.pop(key)

label2id

{195900001: 0,
 195900032: 1,
 195900043: 2,
 196000001: 3,
 196000003: 4,
 196000008: 5,
 196000011: 6,
 196200002: 7,
 196200032: 8,
 196200034: 9,
 196200043: 10,
 196200046: 11,
 196300001: 12,
 196300019: 13,
 196300064: 14,
 196400005: 15,
 196400037: 16,
 196400046: 17,
 196400099: 18,
 196500004: 19,
 196500051: 20,
 196600007: 21,
 196600011: 22,
 196600012: 23,
 196700049: 24,
 196700060: 25,
 196800048: 26,
 196800078: 27,
 196900043: 28,
 197000005: 29,
 197000037: 30,
 197000040: 31,
 197000049: 32,
 197000050: 33,
 197000053: 34,
 197000068: 35,
 197000076: 36,
 197000079: 37,
 197000102: 38,
 197000104: 39,
 197000196: 40,
 197000208: 41,
 197000211: 42,
 197000212: 43,
 197000215: 44,
 197100015: 45,
 197100073: 46,
 197100081: 47,
 197100088: 48,
 197100097: 49,
 197100165: 50,
 197200016: 51,
 197200084: 52,
 197200091: 53,
 197200160: 54,
 197200226: 55,
 197200376: 56,
 197200484: 57,
 197300013: 58,
 197300021: 59,
 197300084: 60,
 197300096: 61,
 197300104: 62,
 1

In [7]:
from datasets import load_dataset

dataset = load_dataset("brainer/drug_info", streaming=True)['train']
class_label = dataset.info.features['label']

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

In [8]:
# dataset['image']의 투명 하지 않은 부분을 bbox로 인식
# bbox의 좌표는 (x_min, y_min, x_max, y_max)로 표현
def get_bbox(image):
    """
    Find the bounding box of the non-transparent part of a PNG image using optimized method.

    :param image: PIL Image object.
    :return: Tuple (left, upper, right, lower) representing the bounding box.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    # Convert the image to a numpy array
    data = np.array(image)

    # Extract the alpha channel and find where it's non-zero
    alpha = data[:, :, 3]
    non_zero_indices = np.argwhere(alpha != 0)

    if non_zero_indices.size == 0:
        return None

    # Find the bounding box coordinates
    min_y, min_x = non_zero_indices.min(axis=0)
    max_y, max_x = non_zero_indices.max(axis=0)

    return [(min_x, min_y, max_x - min_x, max_y - min_y)]


def crop_image(image):
    x, y, w, h = get_bbox(image)[0]
    return image.crop((x, y, x + w, y + h))

In [9]:
from tqdm.notebook import tqdm

model.train()
model = model.to(device)
opt = torch.optim.NAdam(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()


def collate_fn(batch):
    batch_image_list = []
    batch_label_list = []
    for _, data in enumerate(batch):
        image = crop_image(data["image"])
        image = image.convert("RGB")
        image = image.resize((256, 256))
        image = np.asarray(image).reshape(3, 256, 256)
        batch_image_list.append(image)
        batch_label_list.append(data["label"])
    return {'image': batch_image_list, 'label': batch_label_list}


batch_size = 16

torch_dataset = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

for i, data in tqdm(enumerate(torch_dataset)):
    # print(f"data: {data}")
    predict = model(torch.tensor(np.asarray(data['image'])).to(device))

    # for i in data['label']:
    #     print(f"label: {i}")
    #     print(f"class: {class_label.int2str(i)}")

    truth_label = [label2id[int(label)] for label in [class_label.int2str(i) for i in data['label']]]
    truth = torch.zeros(batch_size, len(label2id))

    for i, label in enumerate(truth_label):
        truth[i, label] = 1

    pred_label = []
    for label in torch.argmax(predict, dim=1):
        if len(class_label.names) > label:
            pred_label.append(label2id[int(class_label.int2str(label))])
        else:
            pred_label.append(label2id[int(class_label.int2str(0))])

    pred = torch.zeros(batch_size, len(label2id))
    for i, label in enumerate(pred_label):
        pred[i, label] = 1

    # print(f"pred: {pred}, truth: {truth}")
    opt.zero_grad()
    loss = loss_fn(pred.clone().detach().requires_grad_(True), truth.clone().detach().requires_grad_(True))
    loss.backward()
    opt.step()
    if loss.item() < 0.1:
        break
    if i % 100 == 0:
        print(f"loss: {loss.item()}")

0it [00:00, ?it/s]

In [15]:
model_path = "./model/dino-finetuned-net.pt"

torch.save(model.state_dict(), model_path)

from huggingface_hub import HfApi

api = HfApi()

api.upload_file(
    path_or_fileobj=model_path,
    path_in_repo=f"Feature_Extractor/dino-finetuned-net.pt",
    repo_id="brainer/PGPNet")

dino-finetuned-net.pt:   0%|          | 0.00/269M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/brainer/PGPNet/commit/815d4a218197df2ae69f58e04beee3a5bcc94b13', commit_message='Upload Feature_Extractor/dino-finetuned-net.pt with huggingface_hub', commit_description='', oid='815d4a218197df2ae69f58e04beee3a5bcc94b13', pr_url=None, pr_revision=None, pr_num=None)

In [10]:
batch_size = 16
sample_label = torch.zeros(batch_size, len(label2id))
sample_label[0, 0] = 1
sample_label

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [11]:
sample = next(iter(torch_dataset))
sample

{'image': [array([[[  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          ...,
          [151, 135, 152, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ..., 145, 129, 137]],
  
         [[130, 115, 122, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...,  98,  84,  76],
          ...,
          [ 46,  76,  46, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,  13],
          [ 10,  10, 101, ...,  44,  73,  44]],
  
         [[ 44,  73,  44, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,  45, ...,  44,  71,  42],
          ...,
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...,   0,   0,   0]]], dtype=uint8),
  array([[[  0,   0,   0, ...,   0,   0,   0],
          [  0,   0,   0, ...

In [12]:
# [label2id[int(label)] for label in [class_label.int2str(i) for i in sample['label']]]


# predict_label = [label2id[int(label)] for label in [class_label.int2str(i) for i in torch.argmax(predict, dim=1)]]

pred = model(torch.rand(16, 3, 256, 256).to(device))
[class_label.int2str(int(i)) for i in torch.argmax(pred, dim=1)]

ValueError: Invalid integer class label 21905

# Predict

In [None]:
from PIL import Image
import requests
from io import BytesIO

img_url = "https://img.freepik.com/free-photo/minimal-medicinal-pills-arrangement_23-2148892392.jpg?size=626&ext=jpg&ga=GA1.1.1412446893.1704499200&semt=sph"

response = requests.get(img_url)
img = Image.open(BytesIO(response.content))
img = img.resize(size=(256, 256))
model = model.to(device)
result = model(torch.from_numpy(np.asarray(img)).permute(2, 0, 1).unsqueeze(0).to(device))

In [None]:
id2label[int(torch.argmax(result))]